{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we'll be comparing our Transformer XL architecture with the official implementation.\n",
    "Note: this is a non-refactored, dirty notebook that shouldn't be used as a reference for implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Downloading the reference code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transformer XL reference repo exists\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "if [ -d \"./transformer-xl\" ] \n",
    "then\n",
    "    echo \"Transformer XL reference repo exists\" \n",
    "else\n",
    "    echo \"Cloning Transformer XL repo\" \n",
    "    git clone https://github.com/kimiyoung/transformer-xl.git\n",
    "fi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be using the penn treebank dataset to benchmark our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "DATASET = \"penn\"\n",
    "DATA_DIR = Path(\"../data\") / DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "import sys\n",
    "from pathlib import Path \n",
    "\n",
    "DATASET = \"penn\"\n",
    "REF_PATH = Path(\"./transformer-xl\")\n",
    "DATA_DIR = Path(\"..\") / \"data\" / DATASET\n",
    "\n",
    "sys.path.append(str(REF_PATH / \"pytorch\"))\n",
    "sys.path.append(str(REF_PATH / \"pytorch\" / \"utils\"))\n",
    "\n",
    "TESTING = not IS_KAGGLE_KERNEL # Keep True for now"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "from mem_transformer import RelPartialLearnableMultiHeadAttn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter, OrderedDict\n",
    "#TODO: Clean up\n",
    "import torch\n",
    "\n",
    "class Vocab(object):\n",
    "    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,\n",
    "                 delimiter=None, vocab_file=None):\n",
    "        self.counter = Counter()\n",
    "        self.special = special\n",
    "        self.min_freq = min_freq\n",
    "        self.max_size = max_size\n",
    "        self.lower_case = lower_case\n",
    "        self.delimiter = delimiter\n",
    "        self.vocab_file = vocab_file\n",
    "\n",
    "    def tokenize(self, line, add_eos=False, add_double_eos=False):\n",
    "        line = line.strip()\n",
    "        # convert to lower case\n",
    "        if self.lower_case:\n",
    "            line = line.lower()\n",
    "\n",
    "        # empty delimiter '' will evaluate False\n",
    "        if self.delimiter == '':\n",
    "            symbols = line\n",
    "        else:\n",
    "            symbols = line.split(self.delimiter)\n",
    "\n",
    "        if add_double_eos: # lm1b\n",
    "            return ['<S>'] + symbols + ['<S>']\n",
    "        elif add_eos:\n",
    "            return symbols + ['<eos>']\n",
    "        else:\n",
    "            return symbols\n",
    "\n",
    "    def count_file(self, path, verbose=False, add_eos=False):\n",
    "        if verbose: print('counting file {} ...'.format(path))\n",
    "        assert os.path.exists(path)\n",
    "\n",
    "        sents = []\n",
    "        with open(path, 'r', encoding='utf-8') as f:\n",
    "            for idx, line in enumerate(f):\n",
    "                if verbose and idx > 0 and idx % 500000 == 0:\n",
    "                    print('    line {}'.format(idx))\n",
    "                symbols = self.tokenize(line, add_eos=add_eos)\n",
    "                self.counter.update(symbols)\n",
    "                sents.append(symbols)\n",
    "\n",
    "        return sents\n",
    "\n",
    "    def count_sents(self, sents, verbose=False):\n",
    "        \"\"\"\n",
    "            sents : a list of sentences, each a list of tokenized symbols\n",
    "        \"\"\"\n",
    "        if verbose: print('counting {} sents ...'.format(len(sents)))\n",
    "        for idx, symbols in enumerate(sents):\n",
    "            if verbose and idx > 0 and idx % 500000 == 0:\n",
    "                print('    line {}'.format(idx))\n",
    "            self.counter.update(symbols)\n",
    "\n",
    "    def _build_from_file(self, vocab_file):\n",
    "        self.idx2sym = []\n",
    "        self.sym2idx = OrderedDict()\n",
    "\n",
    "        with open(vocab_file, 'r', encoding='utf-8') as f:\n",
    "            for line in f:\n",
    "                symb = line.strip().split()[0]\n",
    "                self.add_symbol(symb)\n",
    "        self.unk_idx = self.sym2idx['<UNK>']\n",
    "\n",
    "    def build_vocab(self):\n",
    "        if self.vocab_file:\n",
    "            print('building vocab from {}'.format(self.vocab_file))\n",
    "            self._build_from_file(self.vocab_file)\n",
    "            print('final vocab size {}'.format(len(self)))\n",
    "        else:\n",
    "            print('building vocab with min_freq={}, max_size={}'.format(\n",
    "                self.min_freq, self.max_size))\n",
    "            self.idx2sym = []\n",
    "            self.sym2idx = OrderedDict()\n",
    "\n",
    "            for sym in self.special:\n",
    "                self.add_special(sym)\n",
    "\n",
    "            for sym, cnt in self.counter.most_common(self.max_size):\n",
    "                if cnt < self.min_freq: break\n",
    "                self.add_symbol(sym)\n",
    "\n",
    "            print('final vocab size {} from {} unique tokens'.format(\n",
    "                len(self), len(self.counter)))\n",
    "\n",
    "    def encode_file(self, path, ordered=False, verbose=False, add_eos=True,\n",
    "            add_double_eos=False):\n",
    "        if verbose: print('encoding file {} ...'.format(path))\n",
    "        assert os.path.exists(path)\n",
    "        encoded = []\n",
    "        with open(path, 'r', encoding='utf-8') as f:\n",
    "            for idx, line in enumerate(f):\n",
    "                if verbose and idx > 0 and idx % 500000 == 0:\n",
    "                    print('    line {}'.format(idx))\n",
    "                symbols = self.tokenize(line, add_eos=add_eos,\n",
    "                    add_double_eos=add_double_eos)\n",
    "                encoded.append(self.convert_to_tensor(symbols))\n",
    "\n",
    "        if ordered:\n",
    "            encoded = torch.cat(encoded)\n",
    "\n",
    "        return encoded\n",
    "\n",
    "    def encode_sents(self, sents, ordered=False, verbose=False):\n",
    "        if verbose: print('encoding {} sents ...'.format(len(sents)))\n",
    "        encoded = []\n",
    "        for idx, symbols in enumerate(sents):\n",
    "            if verbose and idx > 0 and idx % 500000 == 0:\n",
    "                print('    line {}'.format(idx))\n",
    "            encoded.append(self.convert_to_tensor(symbols))\n",
    "\n",
    "        if ordered:\n",
    "            encoded = torch.cat(encoded)\n",
    "\n",
    "        return encoded\n",
    "\n",
    "    def add_special(self, sym):\n",
    "        if sym not in self.sym2idx:\n",
    "            self.idx2sym.append(sym)\n",
    "            self.sym2idx[sym] = len(self.idx2sym) - 1\n",
    "            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])\n",
    "\n",
    "    def add_symbol(self, sym):\n",
    "        if sym not in self.sym2idx:\n",
    "            self.idx2sym.append(sym)\n",
    "            self.sym2idx[sym] = len(self.idx2sym) - 1\n",
    "\n",
    "    def get_sym(self, idx):\n",
    "        assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)\n",
    "        return self.idx2sym[idx]\n",
    "\n",
    "    def get_idx(self, sym):\n",
    "        if sym in self.sym2idx:\n",
    "            return self.sym2idx[sym]\n",
    "        else:\n",
    "            # print('encounter unk {}'.format(sym))\n",
    "            assert '<eos>' not in sym\n",
    "            assert hasattr(self, 'unk_idx')\n",
    "            return self.sym2idx.get(sym, self.unk_idx)\n",
    "\n",
    "    def get_symbols(self, indices):\n",
    "        return [self.get_sym(idx) for idx in indices]\n",
    "\n",
    "    def get_indices(self, symbols):\n",
    "        return [self.get_idx(sym) for sym in symbols]\n",
    "\n",
    "    def convert_to_tensor(self, symbols):\n",
    "        return torch.LongTensor(self.get_indices(symbols))\n",
    "\n",
    "    def convert_to_sent(self, indices, exclude=None):\n",
    "        if exclude is None:\n",
    "            return ' '.join([self.get_sym(idx) for idx in indices])\n",
    "        else:\n",
    "            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.idx2sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "from vocabulary import Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x11da2a7f0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "torch.manual_seed(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\") if not torch.cuda.is_available() else torch.device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overview"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start off simple by imagining some word embeddings of shape `(seq=7, batch_size=3, embedding_dim=32)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq, batch_size, embedding_dim = 7, 3, 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_embs = torch.rand(seq, batch_size, embedding_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the Transformer XL, we also feed the cached outputs of the model for the previous sequence. In this case, we would be feeding the word embeddings from the previous sequence as additional input to our model.\n",
    "\n",
    "To make things clearer, let's imagine our previous sequence was of length `prev_seq=6`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "memory = torch.rand(6, 3, 32) # hidden states from the previous "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Relative positional embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two sources of attention: the content and position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MHA: The core component"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Aggregating all the above, we get the following MultiHeadAttention module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, d_input: int, d_inner: int, n_heads: int=4, \n",
    "                 dropout: float=0.1, dropouta: float=0.):\n",
    "        super().__init__()\n",
    "        self.d_input = d_input\n",
    "        self.d_inner = d_inner\n",
    "        self.n_heads = n_heads\n",
    "        # this layer applies the linear transformation required\n",
    "        # for the keys and values for all heads at once for efficiency\n",
    "        self.linear_kv = nn.Linear(\n",
    "            d_input, \n",
    "            (d_inner * n_heads * 2), # 2 is for keys and values\n",
    "            bias=False, # we don't apply bias, making this a simple matrix multiplication\n",
    "        )\n",
    "        # for queries (will not be concatenated with memorized states so separate)\n",
    "        self.linear_q = nn.Linear(\n",
    "            d_input, d_inner * n_heads,\n",
    "            bias=False\n",
    "        )\n",
    "        # for positional embeddings\n",
    "        self.linear_p = nn.Linear(\n",
    "            d_input, d_inner * n_heads,\n",
    "            bias=False\n",
    "        )\n",
    "        self.scale = 1 / (d_inner ** 0.5) # for scaled dot product attention\n",
    "        self.dropa = nn.Dropout(dropouta)\n",
    "        # we will use this to project back to the input dimension\n",
    "        self.lout = nn.Linear(self.d_inner * self.n_heads, self.d_input, bias=False)\n",
    "        self.norm = nn.LayerNorm(self.d_input)\n",
    "        self.dropo = nn.Dropout(dropout)\n",
    "        \n",
    "    def _rel_shift(self, x):\n",
    "        # TODO: Understand\n",
    "        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),\n",
    "                               device=x.device, dtype=x.dtype)\n",
    "        x_padded = torch.cat([zero_pad, x], dim=1)\n",
    "\n",
    "        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])\n",
    "\n",
    "        x = x_padded[1:].view_as(x)\n",
    "        return x\n",
    "        \n",
    "    def forward(self, input_: torch.FloatTensor, # (cur_seq, b, d_in)\n",
    "                pos_embs: torch.FloatTensor, # (cur_seq + prev_seq, d_in)\n",
    "                memory: torch.FloatTensor, # (prev_seq, b, d_in)\n",
    "                u: torch.FloatTensor, # (H, d)\n",
    "                v: torch.FloatTensor, # (H, d)\n",
    "                mask: Optional[torch.FloatTensor]=None,\n",
    "        ):\n",
    "        \"\"\"\n",
    "        pos_embs: we pass the positional embeddings in separately\n",
    "            because we need to handle relative positions\n",
    "        input shape: (seq, bs, self.d_input)\n",
    "        pos_embs shape: (seq + prev_seq, bs, self.d_input)\n",
    "        output shape: (seq, bs, self.d_input)\n",
    "        \"\"\"\n",
    "        cur_seq = input_.shape[0] #  sequence length of current segment\n",
    "        prev_seq = memory.shape[0] # sequence length of previous segment\n",
    "        H, d = self.n_heads, self.d_inner\n",
    "        input_with_memory = torch.cat([memory, input_], dim=0) # concatenate recurrent memory\n",
    "                                                               # across sequence dimension\n",
    "\n",
    "        # we will use the following symbols to represent the shape of the tensors\n",
    "        # cs: current sequence length, b: batch, H: number of heads\n",
    "        # d: inner dimension, ps: previous sequence length\n",
    "        # The key and value are now conditioned on the preceding context\n",
    "        k_tfmd, v_tfmd = \\\n",
    "            torch.chunk(self.linear_kv(input_with_memory), 2, dim=-1) # (cs + ps, b, H * d)\n",
    "        q_tfmd = self.linear_q(input_) # (cs, b, H * d)\n",
    "\n",
    "        # apply scaled dot product attention\n",
    "        # look at the following dimensions carefully, since this is the key operation\n",
    "        # in the Transformer/Transformer XL architecture\n",
    "        \n",
    "        _, bs, _ = q_tfmd.shape\n",
    "        assert bs == k_tfmd.shape[1]\n",
    "        # content-based attention term ((a) + (c) in the paper)\n",
    "        # this is the standard attention term in the original Transformer, except without positional embeddings\n",
    "        # which are handled separately in the Transformer XL (see below)\n",
    "        # here, i corresponds to the number of queries = number of current inputs/targets (seq-wise)\n",
    "        # j corresponds to the number of key/values = number of vectors that we can use to compute the \n",
    "        # vector for each query\n",
    "        content_attn = torch.einsum(\"ibhd,jbhd->ijbh\", (\n",
    "                (q_tfmd.view(cur_seq, bs, H, d) + # (a)\n",
    "                 u), # (c): u represents the global (independent of the query)\n",
    "                     # bias towards certain key/values = words\n",
    "                     # Note: maybe this could be a per-attention head parameter?\n",
    "                 k_tfmd.view(cur_seq + prev_seq, bs, H, d) # There is no positional information to be found here\n",
    "        )) # (cs, cs + ps, b, H)\n",
    "        \n",
    "        # position-based attention term ((b) + (d) in the paper)\n",
    "        # this attention is solely based on the position of the key/values\n",
    "        # (i.e. it does not take the content of the key/values into account)\n",
    "        p_tfmd = self.linear_p(pos_embs) # (cs + ps, b, H * d)\n",
    "        position_attn = torch.einsum(\"ibhd,jhd->ijbh\", (\n",
    "                (q_tfmd.view(cur_seq, bs, H, d) + # (b)\n",
    "                 v), # (d): v represents the global (independent of the query)\n",
    "                     # bias towards certain positions\n",
    "                 p_tfmd.view(cur_seq + prev_seq, H, d) # Notice there is not content information\n",
    "                                                        # regarding keys and values here!\n",
    "        )) # (cs, cs + ps, b, H)\n",
    "        \n",
    "        #  ???\n",
    "        position_attn = self._rel_shift(position_attn)\n",
    "        \n",
    "        # the attention is the sum of content-based and position-based attention\n",
    "        attn = content_attn + position_attn\n",
    "\n",
    "        if mask is not None and mask.any().item():\n",
    "            attn = attn.masked_fill(\n",
    "                mask[...,None], -float('inf'))\n",
    "        attn = torch.softmax(attn * self.scale, # rescale to prevent values from exploding\n",
    "                             dim=1) # normalize across the value sequence dimension\n",
    "        attn = self.dropa(attn)\n",
    "        \n",
    "        attn_weighted_values = (torch.einsum(\"ijbh,jbhd->ibhd\",\n",
    "                                           (attn, # (cs, cs + ps, b, H)\n",
    "                                            v_tfmd.view(cur_seq + prev_seq, bs, H, d), # (cs + ps, b, H, d)\n",
    "                                           )) # (cs, b, H, d)\n",
    "                                .contiguous() # we need to change the memory layout to make `view` work\n",
    "                                .view(cur_seq, bs, H * d)) # (cs, b, H * d)\n",
    "\n",
    "        # Project back to input dimension and add residual connection\n",
    "        output = input_ + self.dropo(self.lout(attn_weighted_values))\n",
    "        output = self.norm(output)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "hoge = torch.arange(4 * 5 * 3 * 1).view(4, 5, 3, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = hoge\n",
    "zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),\n",
    "                               device=x.device, dtype=x.dtype)\n",
    "x_padded = torch.cat([zero_pad, x], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x11da2a7f0>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "torch.manual_seed(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test it out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "mha = MultiHeadAttention(32, 17, n_heads=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "inpt = torch.rand(7, 3, 32)\n",
    "pos = torch.rand(13, 32)\n",
    "mem = torch.rand(6, 3, 32)\n",
    "u, v = torch.rand(4, 17), torch.rand(4, 17)\n",
    "x1 = mha(inpt, pos, mem, u, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 3, 32])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.3604, -1.0046,  0.0095, -0.3961,  0.0042,  1.4597,  0.8075, -0.8257,\n",
       "         -2.1602, -0.6140, -0.5043, -0.5235, -1.1914, -0.0094, -0.2105,  0.2834,\n",
       "         -1.0303,  0.8815,  0.4692, -0.1531,  0.9641, -1.3044, -0.4251, -0.5490,\n",
       "         -0.6301,  1.2259,  1.4332, -1.4505,  2.0888, -0.0059,  1.5277,  1.4730],\n",
       "        [ 0.1677,  1.0228, -0.4538, -0.0862,  1.9249,  2.9601,  1.1000, -0.2760,\n",
       "         -0.0615, -0.1409, -0.2373, -1.5995, -1.1702,  1.4656, -1.3063, -0.9157,\n",
       "         -0.9066, -0.4793,  0.3760,  0.4974,  0.0229, -0.4541, -0.4813, -1.0898,\n",
       "          0.8615, -0.2514, -1.2457,  1.1001, -0.2081,  0.3905,  0.5391, -1.0650],\n",
       "        [-0.7000, -0.3955, -0.0049,  1.2113,  1.6559,  1.6313,  0.0602,  0.8077,\n",
       "         -1.4711, -0.2496, -0.3276,  0.6451,  0.7196, -0.3149, -1.3477,  1.2001,\n",
       "         -0.3195, -0.2187, -0.7321, -0.9040,  0.1070, -1.2493, -1.0409,  1.8514,\n",
       "          1.0325,  0.6299, -0.1472, -1.5695, -0.3887, -0.6763, -1.3172,  1.8227]],\n",
       "       grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x11da2a7f0>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "NHEADS = 4\n",
    "DMODEL = 32\n",
    "DINNER = 17"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "mha = MultiHeadAttention(32, 17, n_heads=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiHeadAttention(\n",
       "  (linear_kv): Linear(in_features=32, out_features=136, bias=False)\n",
       "  (linear_q): Linear(in_features=32, out_features=68, bias=False)\n",
       "  (linear_p): Linear(in_features=32, out_features=68, bias=False)\n",
       "  (dropa): Dropout(p=0.0)\n",
       "  (lout): Linear(in_features=68, out_features=32, bias=False)\n",
       "  (norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "  (dropo): Dropout(p=0.1)\n",
       ")"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "mha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "inpt = torch.rand(7, 3, 32)\n",
    "pos = torch.rand(13, 32)\n",
    "mem = torch.rand(6, 3, 32)\n",
    "u, v = torch.rand(4, 17), torch.rand(4, 17)\n",
    "x2 = mha(inpt, pos, mem, u, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.3604, -1.0046,  0.0095, -0.3961,  0.0042,  1.4597,  0.8075, -0.8257,\n",
       "         -2.1602, -0.6140, -0.5043, -0.5235, -1.1914, -0.0094, -0.2105,  0.2834,\n",
       "         -1.0303,  0.8815,  0.4692, -0.1531,  0.9641, -1.3044, -0.4251, -0.5490,\n",
       "         -0.6301,  1.2259,  1.4332, -1.4505,  2.0888, -0.0059,  1.5277,  1.4730],\n",
       "        [ 0.1677,  1.0228, -0.4538, -0.0862,  1.9249,  2.9601,  1.1000, -0.2760,\n",
       "         -0.0615, -0.1409, -0.2373, -1.5995, -1.1702,  1.4656, -1.3063, -0.9157,\n",
       "         -0.9066, -0.4793,  0.3760,  0.4974,  0.0229, -0.4541, -0.4813, -1.0898,\n",
       "          0.8615, -0.2514, -1.2457,  1.1001, -0.2081,  0.3905,  0.5391, -1.0650],\n",
       "        [-0.7000, -0.3955, -0.0049,  1.2113,  1.6559,  1.6313,  0.0602,  0.8077,\n",
       "         -1.4711, -0.2496, -0.3276,  0.6451,  0.7196, -0.3149, -1.3477,  1.2001,\n",
       "         -0.3195, -0.2187, -0.7321, -0.9040,  0.1070, -1.2493, -1.0409,  1.8514,\n",
       "          1.0325,  0.6299, -0.1472, -1.5695, -0.3887, -0.6763, -1.3172,  1.8227]],\n",
       "       grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "x2[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0., grad_fn=<MeanBackward1>)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "x2.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.0007, grad_fn=<StdBackward0>)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "x2.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x11da2a7f0>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "torch.manual_seed(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "mha_ref = RelPartialLearnableMultiHeadAttn(NHEADS, DMODEL, DINNER, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 3, 32])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "mha_ref(inpt, pos, u, v, mems=mem).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.7030e-08, grad_fn=<MeanBackward1>)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "mha_ref(inpt, pos, u, v, mems=mem).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.0007, grad_fn=<StdBackward0>)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "mha_ref(inpt, pos, u, v, mems=mem).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Periphereal items"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0],\n",
       "        [0, 1, 2],\n",
       "        [0, 2, 4],\n",
       "        [0, 3, 6]])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "torch.einsum(\"i,j->ij\", torch.arange(4), torch.arange(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.0000e+00, 5.6234e-01, 3.1623e-01, 1.7783e-01, 1.0000e-01, 5.6234e-02,\n",
       "        3.1623e-02, 1.7783e-02, 1.0000e-02, 5.6234e-03, 3.1623e-03, 1.7783e-03,\n",
       "        1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1 / (10000 ** (torch.arange(0.0, 32, 2.0) / 32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEmbedding(nn.Module):\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d\n",
    "        inv_freq = 1 / (10000 ** (torch.arange(0.0, d, 2.0) / d))\n",
    "        self.register_buffer(\"inv_freq\", inv_freq)\n",
    "        \n",
    "    def forward(self, positions: torch.LongTensor, # (seq, )\n",
    "               ):\n",
    "        sinusoid_inp = torch.einsum(\"i,j->ij\", positions.float(), self.inv_freq)\n",
    "        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)\n",
    "        return pos_emb[:,None,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 1, 32])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: 10 positions embedded with d=32 should come back as (10, 1, 32).\n",
    "embedding = PositionalEmbedding(32)\n",
    "embedding(torch.arange(10).float()).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionwiseFF(nn.Module):\n",
    "    def __init__(self, d_input, d_inner, dropout):\n",
    "        super().__init__()\n",
    "\n",
    "        self.d_input = d_input\n",
    "        self.d_inner = d_inner\n",
    "        self.dropout = dropout\n",
    "        self.ff = nn.Sequential(\n",
    "            nn.Linear(d_input, d_inner), nn.ReLU(inplace=True),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(d_inner, d_input),\n",
    "            nn.Dropout(dropout),\n",
    "        )\n",
    "        self.layer_norm = nn.LayerNorm(d_input)\n",
    "\n",
    "    def forward(self, input_: torch.FloatTensor, # (cur_seq, bs, d_input)\n",
    "               ) -> torch.FloatTensor: # (cur_seq, bs, d_input)\n",
    "        ff_out = self.ff(input_)\n",
    "        output = self.layer_norm(input_ + ff_out)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mem_transformer import PositionwiseFF as RefPositionwiseFF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building the decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderBlock(nn.Module):\n",
    "    \"\"\"One decoder layer: multi-head attention followed by a position-wise FF.\n",
    "\n",
    "    The biases `u` and `v` are not owned by the block; the caller supplies them\n",
    "    on every call (TransformerXL keeps a single pair and passes it to all layers).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_heads, d_input,\n",
    "                 d_head_inner, d_ff_inner,\n",
    "                 dropout, dropouta=0.):\n",
    "        super().__init__()\n",
    "        # `dropouta` is forwarded to the attention module only.\n",
    "        self.mha = MultiHeadAttention(d_input, d_head_inner, n_heads=n_heads,\n",
    "                                      dropout=dropout, dropouta=dropouta)\n",
    "        self.ff = PositionwiseFF(d_input, d_ff_inner, dropout)\n",
    "\n",
    "    def forward(self, input_: torch.FloatTensor, # (cur_seq, bs, d_input)\n",
    "                pos_embs: torch.FloatTensor, # (cur_seq + prev_seq, 1, d_input) as produced by PositionalEmbedding\n",
    "                u: torch.FloatTensor, # (H, d_head_inner) as created by TransformerXL\n",
    "                v: torch.FloatTensor, # (H, d_head_inner) as created by TransformerXL\n",
    "                mask=None,\n",
    "                mems=None,\n",
    "               ):\n",
    "        # Note the argument order: `mems` is the attention's third positional argument.\n",
    "        attended = self.mha(input_, pos_embs, mems, u, v, mask=mask)\n",
    "        return self.ff(attended)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building the adaptive embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StandardWordEmbedding(nn.Module):\n",
    "    \"\"\"\n",
    "    TODO: Implement??\n",
    "    \"\"\"\n",
    "    def __init__(self, num_embeddings, embedding_dim,\n",
    "                div_val=1, sample_softmax=False):\n",
    "        super().__init__()\n",
    "        self.num_embeddings = num_embeddings\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.embedding = nn.Embedding(num_embeddings, embedding_dim)\n",
    "        self.scale = embedding_dim ** 0.5\n",
    "\n",
    "    def forward(self, input_: torch.LongTensor):\n",
    "        return self.embedding(input_) * self.scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two random batches of token ids in [0, 100): 7 and 6 rows, 3 columns each\n",
    "# (presumably (seq_len, batch_size), matching the `idxs` argument of TransformerXL.forward).\n",
    "# NOTE(review): no seed is set in this cell, so the values change between runs unless seeded earlier.\n",
    "idxs1 = torch.randint(100, (7, 3))\n",
    "idxs2 = torch.randint(100, (6, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding table for a 100-token vocabulary with 32-dimensional vectors.\n",
    "wembs = StandardWordEmbedding(num_embeddings=100, embedding_dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 3, 32])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check: (7, 3) token ids -> (7, 3, 32) embeddings.\n",
    "wembs(idxs1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mem_transformer import AdaptiveEmbedding as RefAdaptiveEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mem_transformer import ProjectedAdaptiveLogSoftmax as RefProjectedAdaptiveLogSoftmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building the entire model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TODO: Handle evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mem_transformer import RelPartialLearnableDecoderLayer as RefDecoderLayer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "\n",
    "class TransformerXL(nn.Module):\n",
    "    def __init__(self, num_embeddings, n_layers, n_heads, \n",
    "                 d_model, d_head_inner, d_ff_inner,\n",
    "                 dropout=0.1, dropouta=0., \n",
    "                 seq_len: int=0, mem_len: int=0):\n",
    "        super().__init__()\n",
    "        self.n_layers,self.n_heads,self.d_model,self.d_head_inner,self.d_ff_inner = \\\n",
    "            n_layers,n_heads,d_model,d_head_inner,d_ff_inner\n",
    "        # Embedding layers\n",
    "        self.word_embs = StandardWordEmbedding(num_embeddings, d_model)\n",
    "        self.pos_embs = PositionalEmbedding(d_model)\n",
    "        # Core transformer\n",
    "        self.drop = nn.Dropout(dropout)\n",
    "        self.layers = nn.ModuleList([DecoderBlock(n_heads, d_model, d_head_inner=d_head_inner,\n",
    "                                                  d_ff_inner=d_ff_inner,\n",
    "                                                  dropout=dropout, dropouta=dropouta)\n",
    "                                     for _ in range(n_layers)])\n",
    "\n",
    "        # tie weights\n",
    "        self.output_projection = nn.Linear(d_model, num_embeddings)\n",
    "        self.output_projection.weight = self.word_embs.embedding.weight\n",
    "        self.loss_fn = nn.CrossEntropyLoss() # TODO: Why do we need a special loss?\n",
    "\n",
    "        self.seq_len, self.mem_len = seq_len, mem_len # TODO: Is seq_len being used?\n",
    "        \n",
    "        # TODO: Why is this shared among the layers and heads?\n",
    "        # TODO: Better understand meaning of these parameters\n",
    "        self.u, self.v = (nn.Parameter(torch.Tensor(self.n_heads, self.d_head_inner)),\n",
    "                          nn.Parameter(torch.Tensor(self.n_heads, self.d_head_inner)))\n",
    "        \n",
    "    def init_memory(self, device=torch.device(\"cpu\")) -> torch.FloatTensor:\n",
    "        return [torch.empty(0, dtype=torch.float).to(device) for _ in range(self.n_layers+1)]\n",
    "    \n",
    "    def update_memory(self, \n",
    "            previous_memory: List[torch.FloatTensor], \n",
    "            hidden_states: List[torch.FloatTensor],\n",
    "        ):\n",
    "        assert len(hidden_states) == len(previous_memory)\n",
    "        mem_len, seq_len = previous_memory[0].size(0), hidden_states[0].size(0)\n",
    "\n",
    "        # For the updated memory, we use the most recent `self.mem_len`\n",
    "        # states, including the previous memory\n",
    "        # In other words, if `seq_len` < `self.mem_len` some of the previous memory\n",
    "        # will carry over to the next memory\n",
    "        with torch.no_grad():\n",
    "            new_memory = []\n",
    "            end_idx = mem_len + seq_len\n",
    "            beg_idx = max(0, end_idx - self.mem_len)\n",
    "            # TODO: Make this more efficient\n",
    "            for m, h in zip(previous_memory, hidden_states):\n",
    "                cat = torch.cat([m, h], dim=0) # (mem_len + seq_len, bs, d)\n",
    "                new_memory.append(cat[beg_idx:end_idx].detach()) # (self.mem_len, bs, d)\n",
    "        return new_memory\n",
    "    \n",
    "    def reset_length(self, seq_len, ext_len, mem_len):\n",
    "        self.seq_len = seq_len\n",
    "        self.mem_len = mem_len\n",
    "    \n",
    "    def forward(self, idxs: torch.LongTensor, # (cs, bs)\n",
    "                target: torch.LongTensor, # (cs, bs) -> TODO: Isn:'t this the same?\n",
    "                memory: Optional[List[torch.FloatTensor]]=None,\n",
    "               ) -> Dict[str, torch.Tensor]:\n",
    "        if memory is None: \n",
    "            memory: List[torch.FloatTensor] = self.init_memory(idxs.device)\n",
    "        assert len(memory) == len(self.layers) + 1\n",
    "        cur_seq, bs = idxs.size()\n",
    "        prev_seq = memory[0].size(0)\n",
    "        \n",
    "        # Construct attention mask (TODO: Understand)\n",
    "        dec_attn_mask = torch.triu(\n",
    "            torch.ones((cur_seq, cur_seq + prev_seq)),\n",
    "            diagonal=1 + prev_seq,\n",
    "        ).byte()[...,None].to(idxs.device)\n",
    "        \n",
    "        word_embs = self.drop(self.word_embs(idxs))\n",
    "        # TODO: Understand\n",
    "        pos_idxs = torch.arange(cur_seq + prev_seq - 1, -1, -1.0, dtype=torch.float).to(word_embs.device)\n",
    "        pos_embs = self.drop(self.pos_embs(pos_idxs))\n",
    "        \n",
    "        # Main part of forward pass\n",
    "        hidden_states = [word_embs]\n",
    "        layer_out = word_embs\n",
    "        for mem, layer in zip(memory, self.layers):\n",
    "            layer_out = layer(layer_out, pos_embs, self.u, self.v, \n",
    "                              mask=dec_attn_mask, mems=mem)\n",
    "            hidden_states.append(layer_out)\n",
    "        \n",
    "        logits = self.output_projection(self.drop(layer_out))        \n",
    "        loss = self.loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))\n",
    "        \n",
    "        # Update memory \n",
    "        # Ensure the memory is treated as a constant\n",
    "        # and we do not back propagate through them\n",
    "        new_memory = self.update_memory(memory, hidden_states)\n",
    "        return {\"loss\": loss, \"logits\": logits, \"memory\": new_memory}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.7030e-08, grad_fn=<MeanBackward1>)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): `mha_ref`, `inpt`, `pos`, `u`, `v` and `mem` are not defined in this part of the\n",
    "# notebook; this cell presumably relies on earlier cells and looks like leftover scrap.\n",
    "mha_ref(inpt, pos, u, v, mems=mem).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small model for smoke-testing: vocabulary of 1000, 4 layers, 3 heads, memory of 5 steps.\n",
    "transformer = TransformerXL(\n",
    "    num_embeddings=1000, n_layers=4, n_heads=3,\n",
    "    d_model=32, d_head_inner=17, d_ff_inner=71,\n",
    "    mem_len=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': tensor(22.5404, grad_fn=<NllLossBackward>),\n",
       " 'logits': tensor([[[-2.9090e+00, -6.6362e+00, -6.9902e+00,  ...,  1.0316e+00,\n",
       "            3.1874e+00,  1.3131e-01],\n",
       "          [-5.6880e+00, -3.6444e+00,  1.0460e+01,  ...,  1.1932e+00,\n",
       "           -7.6286e+00, -3.2345e+00],\n",
       "          [-2.9856e+00, -4.9486e+00,  1.1710e+01,  ..., -5.2753e+00,\n",
       "            6.9196e-01,  1.0313e+01],\n",
       "          ...,\n",
       "          [ 2.9315e+00, -1.8000e+00,  8.8548e+00,  ...,  3.9414e+00,\n",
       "           -3.2021e+00,  6.2986e-01],\n",
       "          [ 1.2698e+00,  1.6134e+00,  1.4195e+01,  ...,  7.5257e+00,\n",
       "           -2.0149e+01, -1.2188e+00],\n",
       "          [-4.9988e+00, -5.4103e+00,  1.0239e+01,  ...,  2.6494e+00,\n",
       "           -8.9913e-01,  3.5041e-01]],\n",
       " \n",
       "         [[ 5.4079e+00, -2.9023e+00,  3.8685e+00,  ..., -1.2428e+00,\n",
       "            2.6299e+00,  6.1176e+00],\n",
       "          [ 3.0582e+00, -2.8128e+00,  6.3073e+00,  ...,  6.2343e+00,\n",
       "           -5.1910e+00,  2.3591e+00],\n",
       "          [-5.1048e+00, -1.1722e+01,  1.8561e+01,  ..., -1.3006e+01,\n",
       "            3.0591e+00, -5.4180e+00],\n",
       "          ...,\n",
       "          [-1.3049e+01,  4.6116e+00,  6.4100e+00,  ..., -4.0805e+00,\n",
       "           -1.2703e+01,  3.2632e+00],\n",
       "          [-3.8291e+00,  2.6759e+00,  3.9214e+00,  ..., -1.3588e+00,\n",
       "           -1.0968e+01,  5.1822e+00],\n",
       "          [-3.6167e+00, -4.9403e-01,  2.7270e+00,  ..., -3.1145e-01,\n",
       "            3.0889e+00, -3.0946e+00]],\n",
       " \n",
       "         [[-1.1285e+01,  7.1888e-01,  1.0786e+01,  ..., -8.5734e+00,\n",
       "           -7.2422e+00,  1.1526e+00],\n",
       "          [-9.7996e+00, -2.5140e+00,  7.9750e+00,  ...,  8.8998e+00,\n",
       "           -5.7018e+00, -2.3594e+00],\n",
       "          [ 2.7738e+00, -4.5189e+00,  2.4715e+00,  ...,  2.5420e+00,\n",
       "            9.4204e+00,  7.0859e+00],\n",
       "          ...,\n",
       "          [-1.6867e+00, -6.4232e+00,  2.4668e+00,  ..., -1.2128e+01,\n",
       "            1.5371e+00, -1.1126e+01],\n",
       "          [-1.4214e+00,  4.9909e+00, -2.1618e+00,  ..., -5.4908e+00,\n",
       "           -5.0910e+00,  5.6196e+00],\n",
       "          [-3.1989e+00,  4.5923e+00,  7.6753e+00,  ...,  3.2353e+00,\n",
       "           -2.7371e+00,  1.0900e-01]],\n",
       " \n",
       "         [[ 7.2187e+00, -4.9161e+00,  1.7907e-01,  ...,  2.1525e+00,\n",
       "           -1.3839e+00, -6.0559e+00],\n",
       "          [-5.0362e-01,  2.6605e+00, -2.7521e+00,  ..., -2.9869e+00,\n",
       "           -1.7448e+00,  5.9035e+00],\n",
       "          [ 4.2782e+00, -8.0596e-03,  4.6575e+00,  ..., -5.7534e+00,\n",
       "            2.7254e+00,  2.4870e+00],\n",
       "          ...,\n",
       "          [ 7.5153e+00, -7.7140e+00, -3.4272e+00,  ..., -1.7695e+00,\n",
       "            6.9442e+00, -5.5795e+00],\n",
       "          [ 4.6189e+00,  2.0843e-01, -1.5610e+00,  ..., -8.8573e+00,\n",
       "           -9.6588e+00, -2.0796e-01],\n",
       "          [ 2.8136e+00,  5.7016e+00, -1.2385e+01,  ..., -2.0035e+00,\n",
       "            5.8413e+00,  3.3980e+00]],\n",
       " \n",
       "         [[-2.7728e+00, -8.3331e+00,  5.5866e+00,  ..., -5.6527e+00,\n",
       "            1.5887e+00, -1.2220e+01],\n",
       "          [ 2.8744e-01,  4.7797e+00, -8.1315e+00,  ...,  4.9264e+00,\n",
       "            6.8310e-02,  6.7241e+00],\n",
       "          [-3.1074e+00, -3.9053e+00,  8.5668e+00,  ..., -4.6804e+00,\n",
       "            5.5736e+00,  1.1375e+01],\n",
       "          ...,\n",
       "          [ 1.3142e+00, -1.1708e+00, -6.9842e+00,  ..., -2.1300e+00,\n",
       "           -2.3489e+00, -6.0119e-01],\n",
       "          [ 9.9028e+00,  1.0158e+01, -8.9404e+00,  ..., -5.2356e+00,\n",
       "           -2.6794e+00,  1.0648e+01],\n",
       "          [-3.6775e-01,  4.8810e+00,  6.6016e+00,  ...,  1.4048e+00,\n",
       "           -5.4985e-01,  7.5897e-01]]], grad_fn=<AddBackward0>),\n",
       " 'memory': [tensor([[[ -7.9165,   2.8472,  -6.5213,  ...,  -3.6661,   9.1968,   9.0510],\n",
       "           [ 10.8699,   3.7225,  -1.5479,  ...,   3.4311,  -4.5358,   4.6721],\n",
       "           [-12.3604,   5.1361,   1.0344,  ..., -11.3939,  -5.0266,   1.3282],\n",
       "           ...,\n",
       "           [  1.1281,  -0.0000,   3.6624,  ...,   7.7196,   4.5657,   1.1970],\n",
       "           [ -1.2241,   1.6966,   4.6039,  ...,  -4.5270,  -0.1845,   4.9344],\n",
       "           [  3.4288,  -8.8963,   5.4705,  ...,  -4.0037,  14.2677,  -7.1428]],\n",
       "  \n",
       "          [[ -0.9367,   1.0206,  -4.7811,  ...,  10.8616,  -1.4185, -10.4317],\n",
       "           [ -9.4666,   0.0000,   1.0812,  ...,   9.6719,   4.1617,   6.5190],\n",
       "           [  6.4290,  -8.0596,  -0.3519,  ..., -10.6633,  -2.7553,  -9.2060],\n",
       "           ...,\n",
       "           [-10.0291,   0.6848,   8.4795,  ...,   2.5703,  -7.9746,  -6.9063],\n",
       "           [ -0.0000,   0.0000,  -1.2143,  ...,   2.8127,   4.4696,  -4.3683],\n",
       "           [ -0.0000,   4.0421,  -0.0000,  ...,   0.0801,   0.6824,   7.2616]],\n",
       "  \n",
       "          [[  2.3131,   8.6414,  -3.8427,  ...,  -5.0005,  -4.4974, -12.0917],\n",
       "           [  0.0000,  -8.8963,   5.4705,  ...,  -4.0037,  14.2677,  -7.1428],\n",
       "           [ -5.7094,  -1.1246,  -4.6402,  ..., -11.0839,   3.2963,   1.5237],\n",
       "           ...,\n",
       "           [  8.0156, -20.6471,   1.4269,  ...,   1.1159,   2.9563,   0.6317],\n",
       "           [ -2.8927,  -3.7861,   9.6538,  ...,  -0.0000,  -9.7765,   5.5566],\n",
       "           [ -1.6402,  -9.7288,   2.9966,  ...,  -3.0852,   4.9920,  -0.2886]],\n",
       "  \n",
       "          [[  6.4607,   0.6544,   5.9023,  ...,   2.7852,  -0.0000,   5.1326],\n",
       "           [-10.9169,  -2.7434,   1.4744,  ...,   0.0000,  -2.2983,  -3.8839],\n",
       "           [  2.4380,  -3.9097,  -5.0085,  ...,  -6.1964,  -0.6191,   1.1846],\n",
       "           ...,\n",
       "           [  9.6341,  11.1961,   1.7214,  ...,  -0.0000,  -6.3014,   7.1319],\n",
       "           [ -8.2760,   2.0107,   0.0000,  ...,  -3.5534,   7.6721,  -4.8233],\n",
       "           [  0.1081,   4.0695,  -1.7937,  ...,  13.0986,  -2.7190,   4.2481]],\n",
       "  \n",
       "          [[ 14.0452,  -9.3452,  -5.6232,  ...,  -0.0000,   3.6172,   0.0000],\n",
       "           [  7.0976,   1.9510,   0.1968,  ...,  -5.9646,   9.5041,  -7.8056],\n",
       "           [ -8.6774,   0.5205,  -7.2291,  ...,  -3.7007,  -1.4914,  -3.2516],\n",
       "           ...,\n",
       "           [  6.1078,  -8.3832, -18.7803,  ...,  -5.7648,   5.0374,   1.3115],\n",
       "           [  2.0860,  -7.0435,  -3.5012,  ...,   7.8078,  -0.0000,   2.2856],\n",
       "           [  7.4281,  -5.8197,  -1.3763,  ...,  14.1871,  -9.0098,  -2.4723]]]),\n",
       "  tensor([[[-1.4075e+00,  8.6982e-01, -1.2603e+00,  ..., -4.7981e-01,\n",
       "             1.6762e+00,  8.7335e-01],\n",
       "           [ 1.8234e+00,  1.2566e-01, -1.4367e-01,  ...,  3.2705e-01,\n",
       "            -1.4974e+00,  4.6689e-01],\n",
       "           [-1.8269e+00,  9.6260e-01,  3.0105e-01,  ..., -1.2232e+00,\n",
       "            -5.0117e-01, -2.1643e-01],\n",
       "           ...,\n",
       "           [ 5.7797e-01, -6.9415e-01,  1.2242e+00,  ...,  4.9676e-01,\n",
       "             3.0435e-01, -1.9372e-01],\n",
       "           [-5.6056e-01,  2.9866e-01,  9.4531e-01,  ..., -1.1827e+00,\n",
       "            -1.1301e-02,  1.0765e+00],\n",
       "           [-2.5707e-01, -1.7380e+00,  5.2994e-01,  ..., -6.2729e-01,\n",
       "             2.9025e+00, -8.9594e-01]],\n",
       "  \n",
       "          [[-3.5124e-02,  2.0652e-01, -1.2252e+00,  ...,  1.8766e+00,\n",
       "             8.5698e-01, -1.3338e+00],\n",
       "           [-2.5701e+00, -1.1458e+00, -3.1290e-01,  ...,  5.2470e-01,\n",
       "            -1.2858e-02,  1.8325e-01],\n",
       "           [ 7.9088e-01, -1.0715e+00,  4.1837e-01,  ..., -6.5857e-01,\n",
       "            -9.4773e-01, -1.9623e+00],\n",
       "           ...,\n",
       "           [-1.8101e+00,  6.1763e-01,  2.4034e+00,  ...,  2.3469e-01,\n",
       "            -1.1505e+00, -7.4488e-01],\n",
       "           [-9.7368e-01, -5.6935e-01, -7.2813e-01,  ...,  4.2996e-01,\n",
       "             6.6966e-01, -1.0146e+00],\n",
       "           [-1.0541e-01,  6.8592e-01, -2.8538e-01,  ..., -8.2770e-01,\n",
       "             5.9636e-01,  1.0510e+00]],\n",
       "  \n",
       "          [[ 3.5516e-01,  1.2924e+00, -6.8353e-01,  ..., -4.9510e-01,\n",
       "            -1.9995e-01, -1.7844e+00],\n",
       "           [ 2.5999e-01, -1.3082e+00,  6.9291e-01,  ..., -8.8368e-01,\n",
       "             2.7950e+00, -9.3486e-01],\n",
       "           [-5.8816e-01, -3.1581e-01, -4.1483e-01,  ..., -1.5056e+00,\n",
       "             3.9830e-01,  2.5137e-01],\n",
       "           ...,\n",
       "           [-2.3666e-01, -3.0995e+00,  3.9794e-01,  ...,  3.7001e-01,\n",
       "             4.5640e-01,  2.8671e-01],\n",
       "           [-6.7484e-01, -1.1602e+00,  1.5083e+00,  ..., -7.2485e-01,\n",
       "            -1.7365e+00,  5.6122e-01],\n",
       "           [-2.9310e-01, -1.4260e+00,  6.4370e-01,  ..., -8.2309e-01,\n",
       "             8.0097e-01,  1.0626e-01]],\n",
       "  \n",
       "          [[ 3.7087e-01,  4.3359e-01,  7.8817e-01,  ..., -2.6943e-01,\n",
       "             3.1099e-01,  1.2151e+00],\n",
       "           [-1.4556e+00, -6.2030e-01,  1.1496e+00,  ..., -3.9742e-01,\n",
       "             1.3703e-01, -5.2389e-01],\n",
       "           [-4.3266e-02, -4.9883e-01, -1.9812e-01,  ...,  9.9986e-02,\n",
       "            -1.0681e+00,  1.4361e-01],\n",
       "           ...,\n",
       "           [ 1.4534e+00,  1.1003e+00,  1.1457e+00,  ..., -4.8818e-01,\n",
       "            -1.9578e+00,  2.1714e-01],\n",
       "           [-2.0940e+00,  5.3471e-01, -6.3326e-01,  ..., -1.0996e+00,\n",
       "             1.6551e+00, -1.3830e+00],\n",
       "           [-9.1824e-01,  1.1756e-01, -1.1049e+00,  ...,  8.6140e-01,\n",
       "             3.7476e-01,  5.5450e-01]],\n",
       "  \n",
       "          [[ 1.9232e+00, -1.5901e+00, -1.0685e+00,  ..., -3.3199e-01,\n",
       "             3.3837e-01,  6.4744e-01],\n",
       "           [ 8.2449e-01, -3.5109e-02,  5.0434e-01,  ..., -1.6562e+00,\n",
       "             2.3139e+00, -1.7043e-01],\n",
       "           [-1.1982e+00,  3.9173e-01, -8.7112e-01,  ...,  3.3533e-01,\n",
       "            -1.1350e+00, -2.1864e-01],\n",
       "           ...,\n",
       "           [ 1.6989e+00, -1.2475e+00, -2.2331e+00,  ..., -1.2556e+00,\n",
       "             5.1797e-01, -4.8660e-01],\n",
       "           [ 4.1342e-02, -1.2929e+00, -1.9962e-01,  ...,  1.5845e+00,\n",
       "             1.7763e-01, -7.4199e-04],\n",
       "           [ 3.2834e-01, -1.6274e+00, -8.0584e-01,  ...,  1.9202e+00,\n",
       "            -7.4953e-01, -2.1064e-01]]]),\n",
       "  tensor([[[-1.3306e+00,  1.3007e+00, -1.4306e+00,  ..., -2.7851e-01,\n",
       "             1.6702e+00,  1.8830e-01],\n",
       "           [ 2.1571e+00,  1.1651e+00, -1.9429e-01,  ..., -4.6347e-01,\n",
       "            -1.4137e+00, -1.1520e-01],\n",
       "           [-1.5468e+00,  5.7542e-01,  8.8219e-02,  ..., -1.4885e+00,\n",
       "            -4.6584e-01, -3.7478e-01],\n",
       "           ...,\n",
       "           [ 5.6825e-01, -5.7777e-01,  1.1639e+00,  ..., -5.6258e-01,\n",
       "             6.0227e-01, -9.3269e-01],\n",
       "           [-4.4351e-01,  4.5954e-01,  4.5132e-01,  ..., -1.4899e+00,\n",
       "            -7.0064e-02,  7.0350e-01],\n",
       "           [-2.1070e-04, -1.3854e+00,  3.1747e-01,  ..., -1.0702e+00,\n",
       "             2.3733e+00, -7.6241e-01]],\n",
       "  \n",
       "          [[ 1.9159e-01,  6.2174e-01, -1.4339e+00,  ...,  1.1361e+00,\n",
       "             8.3255e-01, -1.5696e+00],\n",
       "           [-2.2841e+00, -5.6383e-01, -4.6842e-01,  ...,  3.2079e-01,\n",
       "             4.5897e-01, -4.0457e-01],\n",
       "           [ 7.8848e-01, -6.8831e-01,  3.9074e-02,  ..., -1.2791e+00,\n",
       "            -9.3222e-01, -1.7748e+00],\n",
       "           ...,\n",
       "           [-1.8581e+00,  8.1504e-01,  2.2371e+00,  ..., -1.0834e+00,\n",
       "            -8.1911e-01, -1.1646e+00],\n",
       "           [-1.1125e+00, -2.1881e-01, -1.6452e+00,  ...,  1.7361e-01,\n",
       "             3.7580e-01, -7.6474e-01],\n",
       "           [ 7.3826e-01,  7.3637e-01, -1.7424e-01,  ..., -8.3447e-01,\n",
       "             3.8935e-01,  7.4248e-01]],\n",
       "  \n",
       "          [[ 7.4769e-01,  1.7709e+00, -5.0933e-01,  ..., -5.8959e-01,\n",
       "            -5.1849e-01, -2.0002e+00],\n",
       "           [ 5.6208e-01, -1.2316e+00,  5.5854e-01,  ..., -9.9534e-01,\n",
       "             2.9654e+00, -1.2353e+00],\n",
       "           [-4.1333e-01, -1.0097e+00, -6.3980e-01,  ..., -1.5157e+00,\n",
       "            -8.6407e-02,  1.7362e-01],\n",
       "           ...,\n",
       "           [-2.5177e-01, -2.6146e+00,  1.3108e-01,  ..., -5.9070e-01,\n",
       "             5.7003e-01,  2.9616e-01],\n",
       "           [-3.5587e-01, -1.1980e+00,  6.1488e-01,  ..., -8.1254e-01,\n",
       "            -1.5807e+00,  3.7188e-01],\n",
       "           [-1.2082e-01, -1.3386e+00,  2.7545e-01,  ..., -1.0739e+00,\n",
       "             4.6475e-01, -3.7650e-01]],\n",
       "  \n",
       "          [[ 4.7793e-01,  3.7573e-01,  8.5671e-01,  ...,  4.7002e-02,\n",
       "             1.4923e-02,  6.9506e-01],\n",
       "           [-1.2352e+00,  8.5222e-02,  6.1632e-01,  ..., -1.4276e+00,\n",
       "             2.8734e-01, -6.4834e-01],\n",
       "           [ 2.9819e-01, -3.8667e-01, -4.3264e-01,  ..., -3.1432e-03,\n",
       "            -8.4815e-01,  2.4908e-01],\n",
       "           ...,\n",
       "           [ 1.6693e+00,  1.0637e+00,  1.1771e+00,  ..., -1.2600e+00,\n",
       "            -1.5462e+00,  6.2565e-02],\n",
       "           [-2.1533e+00,  6.9254e-01, -7.4350e-01,  ..., -1.0564e+00,\n",
       "             1.3179e+00, -1.2193e+00],\n",
       "           [-8.8244e-01,  1.2114e-01, -1.9055e+00,  ...,  8.5794e-01,\n",
       "             3.3664e-01,  4.4276e-01]],\n",
       "  \n",
       "          [[ 1.8546e+00, -1.3480e+00, -7.5729e-01,  ..., -1.0114e+00,\n",
       "             6.6011e-01,  1.1406e-01],\n",
       "           [ 1.0778e+00, -2.4359e-01,  6.0551e-01,  ..., -2.1570e+00,\n",
       "             2.6368e+00, -6.2534e-01],\n",
       "           [-9.8032e-01,  4.4240e-02, -1.4533e+00,  ..., -6.1588e-01,\n",
       "            -9.5568e-01,  6.0436e-03],\n",
       "           ...,\n",
       "           [ 1.2994e+00, -2.1177e-01, -2.4694e+00,  ..., -1.7859e+00,\n",
       "             4.6789e-01, -1.5110e+00],\n",
       "           [-4.6290e-01, -1.2909e+00, -3.7981e-01,  ...,  1.2334e+00,\n",
       "             2.5336e-02,  1.0665e-01],\n",
       "           [ 6.0746e-01, -1.4156e+00, -6.2642e-01,  ...,  1.8899e+00,\n",
       "            -8.5231e-01, -7.4719e-01]]]),\n",
       "  tensor([[[-1.5884e+00,  6.1805e-01, -1.2781e+00,  ..., -3.9332e-01,\n",
       "             1.7109e+00,  7.0903e-01],\n",
       "           [ 1.9797e+00,  5.8025e-01,  3.1218e-01,  ..., -2.9067e-01,\n",
       "            -1.4205e+00,  5.3232e-02],\n",
       "           [-1.4291e+00, -9.4496e-04,  3.5576e-03,  ..., -1.3392e+00,\n",
       "            -4.4792e-01, -1.1582e+00],\n",
       "           ...,\n",
       "           [ 4.3028e-01, -4.3230e-01,  8.5661e-01,  ..., -8.3670e-01,\n",
       "            -2.9372e-01, -1.1582e+00],\n",
       "           [-6.4740e-01,  1.6148e+00,  4.2838e-01,  ..., -1.1808e+00,\n",
       "            -1.7271e-01,  1.5395e-01],\n",
       "           [-2.9026e-01, -1.2072e+00,  1.9272e-01,  ..., -6.3202e-01,\n",
       "             1.9062e+00, -1.0649e-01]],\n",
       "  \n",
       "          [[-1.2886e-01,  3.1420e-01, -1.7663e+00,  ...,  7.3960e-01,\n",
       "             1.2968e+00, -8.9929e-01],\n",
       "           [-2.5131e+00, -5.7835e-01, -3.2877e-01,  ...,  5.5650e-01,\n",
       "             1.0952e+00,  1.9868e-01],\n",
       "           [ 9.4452e-01, -1.0268e+00,  8.1470e-02,  ..., -8.2131e-01,\n",
       "            -5.1620e-01, -1.6072e+00],\n",
       "           ...,\n",
       "           [-2.6731e+00,  8.3889e-02,  1.7946e+00,  ..., -1.1256e+00,\n",
       "            -6.0699e-01, -1.7725e+00],\n",
       "           [-2.0159e+00,  4.4735e-01, -1.3473e+00,  ...,  5.0926e-01,\n",
       "             4.4100e-01, -1.5061e+00],\n",
       "           [ 8.7373e-01,  2.7490e-01, -1.0158e-01,  ..., -1.0099e+00,\n",
       "            -3.1954e-01,  1.0238e+00]],\n",
       "  \n",
       "          [[ 6.5599e-01,  9.3250e-01, -6.3902e-01,  ..., -1.2069e+00,\n",
       "            -2.5376e-01, -1.8317e+00],\n",
       "           [ 1.8772e-01, -1.1102e+00,  2.8304e-01,  ..., -8.6429e-01,\n",
       "             2.6687e+00, -9.5265e-01],\n",
       "           [-5.1257e-01, -9.8227e-01, -7.3705e-01,  ..., -1.2598e+00,\n",
       "            -7.6278e-01, -4.5666e-02],\n",
       "           ...,\n",
       "           [-4.7853e-01, -2.4969e+00,  2.2361e-01,  ..., -3.7042e-01,\n",
       "             6.4116e-01, -3.8228e-01],\n",
       "           [-4.7250e-01, -1.5545e+00,  5.3627e-01,  ..., -2.1755e-02,\n",
       "            -1.3120e+00,  2.1655e-01],\n",
       "           [-1.0030e-01, -2.0830e+00,  3.8006e-01,  ..., -1.5167e+00,\n",
       "             1.5243e-01,  7.1553e-01]],\n",
       "  \n",
       "          [[ 1.0350e+00, -2.2143e-01,  3.3704e-01,  ...,  3.1405e-01,\n",
       "            -5.4255e-01,  1.2838e+00],\n",
       "           [-1.4437e+00, -2.9796e-01,  8.1348e-01,  ..., -8.0408e-01,\n",
       "             2.3649e-02, -8.1959e-01],\n",
       "           [ 4.4314e-01, -3.5742e-01, -5.3478e-01,  ...,  1.0716e-01,\n",
       "            -8.2855e-01,  7.0502e-01],\n",
       "           ...,\n",
       "           [ 1.0854e+00,  7.1513e-01,  1.2091e+00,  ..., -9.9969e-01,\n",
       "            -2.0247e+00, -4.8723e-01],\n",
       "           [-2.4712e+00,  5.3765e-01, -2.8889e-01,  ..., -3.8338e-01,\n",
       "             1.0235e+00, -1.8646e+00],\n",
       "           [-1.1549e+00, -3.9668e-01, -1.5902e+00,  ...,  9.9944e-01,\n",
       "             4.1482e-01,  1.4203e+00]],\n",
       "  \n",
       "          [[ 2.2718e+00, -1.2994e+00, -8.6441e-01,  ..., -1.1788e+00,\n",
       "             7.9265e-01,  3.5393e-01],\n",
       "           [ 3.7257e-01, -1.7166e-01,  5.1031e-01,  ..., -2.1490e+00,\n",
       "             2.8111e+00, -8.5581e-01],\n",
       "           [-9.5836e-01, -6.4319e-01, -1.3975e+00,  ..., -6.2635e-01,\n",
       "            -5.8408e-01,  3.4565e-01],\n",
       "           ...,\n",
       "           [ 2.4768e-01, -6.4910e-01, -1.8707e+00,  ..., -1.6556e+00,\n",
       "             1.3497e-01, -1.4711e+00],\n",
       "           [-8.3703e-01, -1.4937e+00,  3.1808e-01,  ...,  1.0935e+00,\n",
       "            -2.4409e-02, -1.1546e-01],\n",
       "           [ 6.2538e-01, -1.4874e+00, -3.1578e-01,  ...,  1.8277e+00,\n",
       "            -1.1067e+00,  5.5813e-01]]]),\n",
       "  tensor([[[-1.1847e+00,  6.8695e-01, -1.0152e+00,  ..., -6.1173e-01,\n",
       "             1.5514e+00,  7.6128e-01],\n",
       "           [ 2.2801e+00,  7.6453e-01,  2.7315e-01,  ..., -4.1242e-01,\n",
       "            -1.7667e+00, -2.0119e-01],\n",
       "           [-1.3456e+00,  6.7236e-01, -2.3195e-01,  ..., -1.5406e+00,\n",
       "            -5.8220e-01, -1.1211e+00],\n",
       "           ...,\n",
       "           [ 7.2879e-01, -1.1032e+00,  8.7577e-01,  ..., -1.3916e+00,\n",
       "            -1.0353e+00, -7.0926e-01],\n",
       "           [-7.0845e-01,  1.5747e+00,  6.4036e-01,  ..., -1.5892e+00,\n",
       "            -5.0075e-01,  1.0418e-01],\n",
       "           [-1.0603e+00, -1.2515e+00,  8.6979e-01,  ..., -6.4899e-01,\n",
       "             1.5302e+00, -4.5006e-01]],\n",
       "  \n",
       "          [[ 4.7040e-02,  1.2195e-01, -1.6549e+00,  ...,  1.1142e-01,\n",
       "             1.1686e+00, -8.1714e-01],\n",
       "           [-2.6486e+00, -8.6591e-01,  4.5801e-01,  ...,  4.1079e-01,\n",
       "             1.4000e+00,  2.8311e-02],\n",
       "           [ 1.2563e+00, -9.8938e-02,  1.1666e-01,  ..., -1.0930e+00,\n",
       "            -1.0385e+00, -1.7736e+00],\n",
       "           ...,\n",
       "           [-2.2545e+00, -4.5535e-03,  1.2991e+00,  ..., -4.4985e-01,\n",
       "            -5.4313e-01, -1.7055e+00],\n",
       "           [-2.1233e+00,  1.0336e+00, -9.4529e-01,  ..., -1.2327e-01,\n",
       "             9.1958e-02, -1.2951e+00],\n",
       "           [ 8.3094e-01,  8.4578e-01,  2.0603e-01,  ..., -1.0601e+00,\n",
       "            -5.9665e-01,  3.9248e-01]],\n",
       "  \n",
       "          [[ 4.9504e-01,  8.7389e-01, -3.7656e-01,  ..., -1.0669e+00,\n",
       "            -4.4241e-01, -1.8652e+00],\n",
       "           [-1.0069e+00, -7.9181e-01,  1.0450e+00,  ..., -1.0018e+00,\n",
       "             2.3945e+00, -1.4237e+00],\n",
       "           [-8.0871e-01, -4.3816e-01, -4.5174e-01,  ..., -1.1247e+00,\n",
       "            -1.1417e+00, -1.0020e-02],\n",
       "           ...,\n",
       "           [-4.4592e-01, -1.9416e+00,  5.7453e-02,  ..., -3.2818e-01,\n",
       "             6.5988e-01, -7.1136e-01],\n",
       "           [-7.4762e-01, -7.0626e-01, -5.8865e-02,  ...,  4.9279e-01,\n",
       "            -1.2286e+00, -1.2655e-01],\n",
       "           [-5.4255e-01, -1.9482e+00,  1.0026e+00,  ..., -1.3278e+00,\n",
       "            -1.3875e-02,  3.7851e-01]],\n",
       "  \n",
       "          [[ 1.2816e+00, -7.1701e-02,  2.0335e-01,  ..., -2.6528e-02,\n",
       "            -9.6103e-01,  1.3490e+00],\n",
       "           [-1.8818e+00, -1.8715e-01,  1.1470e+00,  ..., -2.3958e-01,\n",
       "             3.2411e-01, -6.9995e-01],\n",
       "           [ 2.0575e-01, -4.3970e-01,  2.9533e-01,  ...,  7.0539e-02,\n",
       "            -8.7685e-01,  8.8467e-01],\n",
       "           ...,\n",
       "           [ 7.8125e-01,  8.7302e-01,  1.1469e+00,  ..., -4.3492e-01,\n",
       "            -2.0992e+00, -4.0003e-01],\n",
       "           [-2.6170e+00,  1.0163e+00, -2.6021e-01,  ...,  1.1911e-01,\n",
       "             1.0392e+00, -2.1064e+00],\n",
       "           [-1.0843e+00, -1.8318e-01, -1.4991e+00,  ...,  1.0146e+00,\n",
       "             2.2281e-01,  1.3325e+00]],\n",
       "  \n",
       "          [[ 1.9782e+00, -1.3297e+00, -4.5866e-01,  ..., -1.0480e+00,\n",
       "             4.2029e-01,  5.9025e-01],\n",
       "           [-6.8445e-01, -2.0235e-04,  1.0275e+00,  ..., -1.7569e+00,\n",
       "             2.4323e+00, -1.3304e+00],\n",
       "           [-1.1725e+00, -4.9273e-01, -9.2044e-01,  ..., -4.1798e-01,\n",
       "            -6.8016e-01,  3.5324e-01],\n",
       "           ...,\n",
       "           [ 2.4481e-01, -2.8389e-01, -1.8168e+00,  ..., -1.4905e+00,\n",
       "            -1.1919e-01, -1.2992e+00],\n",
       "           [-7.5217e-01, -7.5582e-01, -2.1934e-01,  ...,  1.3759e+00,\n",
       "            -3.5222e-01, -3.7040e-01],\n",
       "           [ 5.8493e-01, -1.3647e+00, -8.1247e-02,  ...,  2.1962e+00,\n",
       "            -1.2871e+00,  3.4107e-01]]])]}"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idxs = torch.randint(1000, (5, 9))\n",
    "tgts = torch.randint(1000, (5, 9))\n",
    "transformer(idxs, tgts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils import data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LMDataLoader(data.DataLoader):\n",
    "    \"\"\"\n",
    "    Iterates over a 1-D corpus of token ids in consecutive, column-aligned segments.\n",
    "\n",
    "    Suppose batch size is 4 and our entire corpus looks like this:\n",
    "    'pytorch is an amazing deep learning framework that makes nlp really easy'\n",
    "    To take advantage of segment-level recurrence in the Transformer XL, column i of every\n",
    "    batch has to continue the text of column i of the previous batch, so that whatever was\n",
    "    cached for a column is the context that directly precedes its new tokens.\n",
    "\n",
    "    In other words, the corpus is cut into `batch_size` equal streams that become columns:\n",
    "    Row 1: pytorch   amazing   framework nlp\n",
    "    Row 2: is        deep      that      really\n",
    "    Row 3: an        learning  makes     easy\n",
    "    Notice that you can reconstruct the original sentence by reading from top to bottom, left to right\n",
    "    instead of left to right, top to bottom.\n",
    "    A batch is a slice of up to `bptt` (back propagation through time) consecutive rows, so a\n",
    "    minibatch has shape (bptt, batch_size). The target is the same slice shifted down by one\n",
    "    row, which is why the very last row is only ever used as a target, never as an input.\n",
    "\n",
    "    Args:\n",
    "        data: 1-D tensor holding the token ids of the whole corpus.\n",
    "        batch_size: number of parallel streams (columns) per batch.\n",
    "        bptt: maximum number of rows (time steps) per batch.\n",
    "        device: device the reshaped corpus is moved to.\n",
    "\n",
    "    Yields:\n",
    "        (inputs, targets, seq_len) where inputs and targets have shape\n",
    "        (seq_len, batch_size) and seq_len <= bptt.\n",
    "\n",
    "    Note:\n",
    "        `data.DataLoader.__init__` is not called; only `__iter__` and `__len__` are provided here.\n",
    "    \"\"\"\n",
    "    def __init__(self, data: torch.LongTensor, batch_size: int, bptt: int,\n",
    "                 device=torch.device(\"cpu\")):\n",
    "        self.batch_size = batch_size\n",
    "        self.bptt = bptt\n",
    "        self.n_steps = data.size(0) // batch_size\n",
    "\n",
    "        # we reshape the data here so that we can index\n",
    "        # efficiently into it while training\n",
    "        self.data = (data[:self.n_steps * batch_size] # trim off any elements that don't fit cleanly\n",
    "                     .view(batch_size, self.n_steps) # one contiguous stream of the corpus per row\n",
    "                     .transpose(0, 1) # -> (n_steps, batch_size): every stream becomes a column\n",
    "                     .contiguous().to(device) # put on device as contiguous tensor\n",
    "                     )\n",
    "\n",
    "    def _batch_starts(self):\n",
    "        # First row of every batch. The final row has no successor to predict,\n",
    "        # so it is excluded from the inputs (hence the `- 1`).\n",
    "        return range(0, self.data.size(0) - 1, self.bptt)\n",
    "\n",
    "    def __iter__(self):\n",
    "        last_row = self.data.size(0) - 1\n",
    "        for batch_start_idx in self._batch_starts():\n",
    "            batch_end_idx = min(batch_start_idx + self.bptt, last_row)\n",
    "            # TODO: What is `self.ext_len` in the original code?\n",
    "            batch_data = self.data[batch_start_idx:batch_end_idx]\n",
    "            target = self.data[batch_start_idx+1:batch_end_idx+1]\n",
    "            # we generate the sequence length as well for loss calculation later\n",
    "            yield batch_data, target, batch_end_idx - batch_start_idx\n",
    "\n",
    "    def __len__(self):\n",
    "        # Derived from the same range as __iter__ so that the reported number of\n",
    "        # batches always equals the number of batches actually yielded.\n",
    "        return len(self._batch_starts())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_corpus = torch.randint(1000, (1600, ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "BS = 16\n",
    "BPTT = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([495, 258, 302, 862, 367, 647, 251,  60, 770, 306])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_corpus[:BPTT]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = LMDataLoader(test_corpus, BS, BPTT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "b1, *_ = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 16])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[495, 823, 591, 439, 717, 391, 921, 848, 826, 625, 352, 577, 559, 882,\n",
       "         340, 330],\n",
       "        [258,  53, 222, 136, 839, 357, 772, 366, 135, 635, 747, 604, 829, 550,\n",
       "         609, 504],\n",
       "        [302, 762, 395, 843, 321,  45, 263, 204, 120, 937, 682,  93, 268, 999,\n",
       "         905, 529],\n",
       "        [862, 283, 565, 346, 716, 402, 875, 630, 117, 454, 756, 144, 134, 508,\n",
       "         730, 466],\n",
       "        [367, 197, 109, 294,  83, 455, 879, 790,   9, 591, 333,  71, 329, 223,\n",
       "         893,  74],\n",
       "        [647, 550,  79, 354, 476, 528, 581, 794, 682, 313,  80, 317, 840, 373,\n",
       "         299, 616],\n",
       "        [251, 321, 985,  79, 844, 498, 423, 361,  17, 264, 373, 913, 575, 623,\n",
       "         550, 332],\n",
       "        [ 60, 334,  63, 396, 135, 498, 527, 847, 365, 316, 146, 929, 327, 341,\n",
       "          98, 763],\n",
       "        [770, 921, 551, 646, 347, 515, 946, 130, 211, 605, 924, 891, 745, 970,\n",
       "         231, 421],\n",
       "        [306, 105, 416, 760, 965, 631, 320, 275, 651, 490, 723, 680, 148, 753,\n",
       "         432, 553]])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "b1, b2, sl = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_xl = TransformerXL(1000, n_layers=4, n_heads=3, \n",
    "                               d_model=32, d_head_inner=17, d_ff_inner=71)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weight(weight):\n",
    "    \"\"\"Fill `weight` in place with normal samples (mean 0, std 0.02).\"\"\"\n",
    "    nn.init.normal_(weight, 0.0, 0.02)\n",
    "\n",
    "def init_bias(bias):\n",
    "    \"\"\"Fill `bias` in place with zeros.\"\"\"\n",
    "    nn.init.constant_(bias, 0.0)\n",
    "\n",
    "# Borrowed from the transformer XL repo\n",
    "def weights_init(m):\n",
    "    \"\"\"\n",
    "    Initialise a single module in place; used as `model.apply(weights_init)`.\n",
    "\n",
    "    The kind of module is detected from substrings of its class name:\n",
    "      - 'Linear':    weight ~ normal(0, 0.02), bias = 0\n",
    "      - 'Embedding': weight ~ normal(0, 0.02)\n",
    "      - 'LayerNorm': weight ~ normal(1, 0.02), bias = 0\n",
    "      - otherwise:   attributes `u` and `v`, when present, ~ normal(0, 0.02)\n",
    "    Parameters that are missing or None (e.g. `bias=False` or\n",
    "    `elementwise_affine=False`) are skipped instead of raising.\n",
    "    \"\"\"\n",
    "    classname = m.__class__.__name__\n",
    "    if 'Linear' in classname:\n",
    "        if getattr(m, 'weight', None) is not None:\n",
    "            init_weight(m.weight)\n",
    "        if getattr(m, 'bias', None) is not None:\n",
    "            init_bias(m.bias)\n",
    "    elif 'Embedding' in classname:\n",
    "        if getattr(m, 'weight', None) is not None:\n",
    "            init_weight(m.weight)\n",
    "    elif 'LayerNorm' in classname:\n",
    "        # weight is registered as None when the norm has no affine parameters\n",
    "        if getattr(m, 'weight', None) is not None:\n",
    "            nn.init.normal_(m.weight, 1.0, 0.02)\n",
    "        if getattr(m, 'bias', None) is not None:\n",
    "            init_bias(m.bias)\n",
    "    else:\n",
    "        if getattr(m, 'u', None) is not None:\n",
    "            init_weight(m.u)\n",
    "        if getattr(m, 'v', None) is not None:\n",
    "            init_weight(m.v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_xl.apply(weights_init);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "def assert_nonan_params(m):\n",
    "    \"\"\"Raise ValueError naming the first parameter of `m` that contains a NaN.\"\"\"\n",
    "    offender = next(\n",
    "        (name for name, tensor in m.named_parameters() if torch.isnan(tensor).any()),\n",
    "        None,\n",
    "    )\n",
    "    if offender is not None:\n",
    "        raise ValueError(f\"{offender} has nan weights\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert_nonan_params(transformer_xl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': tensor(6.9132, grad_fn=<NllLossBackward>),\n",
       " 'logits': tensor([[[-1.0248e-01,  1.0731e-01,  1.4011e-02,  ...,  6.9848e-02,\n",
       "            2.4726e-01,  2.2748e-01],\n",
       "          [-1.3678e-03, -4.6101e-02,  4.6903e-02,  ...,  1.5692e-01,\n",
       "           -6.4016e-02,  6.1876e-02],\n",
       "          [-4.5677e-02, -1.4621e-01,  1.3201e-02,  ...,  2.0921e-02,\n",
       "           -1.6673e-01, -6.3966e-02],\n",
       "          ...,\n",
       "          [ 1.3626e-01,  1.9220e-01, -4.0781e-02,  ..., -5.7603e-02,\n",
       "           -5.6082e-02,  4.1659e-02],\n",
       "          [-4.4718e-02,  1.5178e-02, -2.1175e-01,  ...,  1.4830e-01,\n",
       "           -7.5375e-02, -1.7004e-01],\n",
       "          [-4.6929e-02,  6.4080e-02,  3.6388e-02,  ...,  1.3949e-01,\n",
       "           -2.7204e-02, -1.4109e-01]],\n",
       " \n",
       "         [[-1.0999e-01, -1.4124e-01, -1.5175e-01,  ...,  8.1961e-02,\n",
       "           -5.8229e-02, -3.6164e-04],\n",
       "          [-1.4177e-01,  1.1944e-01, -5.5737e-02,  ..., -2.3964e-02,\n",
       "            1.8713e-01, -1.0519e-01],\n",
       "          [ 2.6806e-02,  7.5983e-02,  6.9189e-02,  ...,  8.6324e-02,\n",
       "            8.4597e-02, -2.3676e-02],\n",
       "          ...,\n",
       "          [-1.6791e-01,  8.5879e-02,  1.7494e-01,  ..., -3.3066e-02,\n",
       "           -1.5778e-01, -8.8234e-04],\n",
       "          [ 7.2104e-02,  1.2902e-01,  5.3687e-02,  ..., -1.5645e-01,\n",
       "           -7.1781e-03, -1.1068e-01],\n",
       "          [ 1.4564e-01,  3.5282e-02, -5.8816e-02,  ..., -1.1506e-01,\n",
       "            2.9478e-02,  1.8000e-01]],\n",
       " \n",
       "         [[-7.2060e-02, -8.4643e-03,  6.4367e-02,  ..., -1.3164e-01,\n",
       "           -1.3614e-01,  1.9446e-02],\n",
       "          [ 1.4124e-01,  7.9030e-02, -8.5568e-02,  ..., -1.2929e-01,\n",
       "            5.8081e-03, -8.4492e-02],\n",
       "          [-2.5493e-02,  1.2458e-01,  7.0857e-02,  ...,  3.5657e-02,\n",
       "           -8.3504e-02, -8.0807e-02],\n",
       "          ...,\n",
       "          [-3.0617e-02, -1.4736e-02, -4.8627e-02,  ...,  4.0153e-03,\n",
       "            1.6265e-01,  5.8908e-01],\n",
       "          [-9.2850e-02, -1.3584e-01,  2.3585e-01,  ..., -1.4457e-01,\n",
       "            9.7202e-03, -2.6095e-01],\n",
       "          [-8.5778e-02,  8.6147e-02,  2.1940e-02,  ...,  1.3798e-01,\n",
       "            1.7101e-02, -1.3712e-02]],\n",
       " \n",
       "         ...,\n",
       " \n",
       "         [[ 9.1503e-02, -3.7148e-02, -2.2180e-02,  ..., -7.5079e-02,\n",
       "           -2.0175e-02,  1.7088e-02],\n",
       "          [-7.6871e-02,  5.5227e-02, -1.3287e-01,  ...,  4.0191e-02,\n",
       "            1.4266e-01,  1.4642e-01],\n",
       "          [ 1.3682e-01, -9.1742e-02, -6.0162e-02,  ...,  2.3565e-02,\n",
       "           -1.3034e-01, -2.3142e-02],\n",
       "          ...,\n",
       "          [-1.6963e-01,  8.1763e-02, -3.8774e-02,  ..., -3.9358e-02,\n",
       "           -2.0138e-02, -4.5745e-03],\n",
       "          [-1.3889e-02, -4.4796e-02, -1.8235e-01,  ..., -6.2257e-03,\n",
       "            9.6230e-02, -4.3023e-02],\n",
       "          [ 6.3552e-02,  3.1558e-02, -1.2285e-01,  ...,  9.1621e-03,\n",
       "           -1.1906e-03, -8.2683e-02]],\n",
       " \n",
       "         [[ 1.3720e-01,  6.6950e-02, -7.8836e-03,  ..., -3.5002e-01,\n",
       "            7.5044e-02, -6.8420e-03],\n",
       "          [-9.4781e-02,  2.3919e-01, -2.9293e-02,  ..., -1.8158e-01,\n",
       "            1.1142e-02, -1.5246e-01],\n",
       "          [-3.8051e-02,  1.5883e-01, -2.7557e-02,  ..., -1.5475e-02,\n",
       "            7.5417e-02, -5.9191e-03],\n",
       "          ...,\n",
       "          [-9.5314e-02,  9.2642e-02, -2.6091e-01,  ..., -3.4878e-02,\n",
       "            1.1088e-01,  1.3217e-01],\n",
       "          [-8.7632e-02,  1.0678e-01,  2.4481e-01,  ..., -1.2132e-01,\n",
       "           -4.0339e-02,  1.2571e-01],\n",
       "          [-1.2485e-01,  1.1252e-01,  2.2078e-01,  ..., -1.1770e-01,\n",
       "           -8.2886e-02,  8.7156e-02]],\n",
       " \n",
       "         [[ 3.6904e-02, -3.0857e-01,  5.5182e-02,  ...,  1.7782e-03,\n",
       "           -1.5544e-01, -3.7041e-02],\n",
       "          [ 1.5587e-01, -1.8009e-02, -1.4837e-01,  ...,  1.4830e-01,\n",
       "           -8.8462e-02,  7.4064e-02],\n",
       "          [ 7.9205e-02, -5.7751e-02,  1.2568e-01,  ..., -1.5825e-01,\n",
       "            1.0375e-01,  1.2069e-01],\n",
       "          ...,\n",
       "          [ 8.7838e-02, -4.8632e-02, -1.3849e-01,  ..., -5.1900e-02,\n",
       "            2.5409e-02, -1.5170e-01],\n",
       "          [-1.6238e-01, -5.9545e-02, -2.0494e-02,  ..., -4.7279e-02,\n",
       "           -1.2305e-03,  1.3663e-01],\n",
       "          [-1.3301e-01, -1.2666e-01,  1.4380e-02,  ..., -1.0582e-01,\n",
       "            2.8081e-02,  8.3359e-02]]], grad_fn=<AddBackward0>),\n",
       " 'memory': [tensor([], size=(0, 16, 32)),\n",
       "  tensor([], size=(0, 16, 32)),\n",
       "  tensor([], size=(0, 16, 32)),\n",
       "  tensor([], size=(0, 16, 32)),\n",
       "  tensor([], size=(0, 16, 32))]}"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformer_xl(b1, b2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "import mem_transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "def ref_init_weight(weight):\n",
    "    \"\"\"Fill `weight` in place with normal samples (mean 0, std 0.02).\"\"\"\n",
    "    nn.init.normal_(weight, 0.0, 0.02)\n",
    "\n",
    "def ref_init_bias(bias):\n",
    "    \"\"\"Fill `bias` in place with zeros.\"\"\"\n",
    "    nn.init.constant_(bias, 0.0)\n",
    "\n",
    "def ref_weights_init(m):\n",
    "    \"\"\"\n",
    "    Initialise a single module of the reference model in place;\n",
    "    used as `model.apply(ref_weights_init)`.\n",
    "\n",
    "    Dispatches on substrings of the class name. 'AdaptiveEmbedding' is tested before the\n",
    "    more general 'Embedding' so that it reaches its own branch.\n",
    "    Weights/biases that are missing or None (e.g. a LayerNorm without affine\n",
    "    parameters) are skipped instead of raising.\n",
    "    \"\"\"\n",
    "    classname = m.__class__.__name__\n",
    "    if 'Linear' in classname:\n",
    "        if getattr(m, 'weight', None) is not None:\n",
    "            ref_init_weight(m.weight)\n",
    "        if getattr(m, 'bias', None) is not None:\n",
    "            ref_init_bias(m.bias)\n",
    "    elif 'AdaptiveEmbedding' in classname:\n",
    "        # only the projection matrices are touched here, with a smaller std (0.01)\n",
    "        if hasattr(m, 'emb_projs'):\n",
    "            for proj in m.emb_projs:\n",
    "                if proj is not None:\n",
    "                    nn.init.normal_(proj, 0.0, 0.01)\n",
    "    elif 'Embedding' in classname:\n",
    "        if getattr(m, 'weight', None) is not None:\n",
    "            ref_init_weight(m.weight)\n",
    "    elif 'ProjectedAdaptiveLogSoftmax' in classname:\n",
    "        if getattr(m, 'cluster_weight', None) is not None:\n",
    "            ref_init_weight(m.cluster_weight)\n",
    "        if getattr(m, 'cluster_bias', None) is not None:\n",
    "            ref_init_bias(m.cluster_bias)\n",
    "        if hasattr(m, 'out_projs'):\n",
    "            for proj in m.out_projs:\n",
    "                if proj is not None:\n",
    "                    nn.init.normal_(proj, 0.0, 0.01)\n",
    "    elif 'LayerNorm' in classname:\n",
    "        # weight is registered as None when the norm has no affine parameters\n",
    "        if getattr(m, 'weight', None) is not None:\n",
    "            nn.init.normal_(m.weight, 1.0, 0.02)\n",
    "        if getattr(m, 'bias', None) is not None:\n",
    "            ref_init_bias(m.bias)\n",
    "    elif 'TransformerLM' in classname:\n",
    "        if hasattr(m, 'r_emb'):\n",
    "            ref_init_weight(m.r_emb)\n",
    "        if hasattr(m, 'r_w_bias'):\n",
    "            ref_init_weight(m.r_w_bias)\n",
    "        if hasattr(m, 'r_r_bias'):\n",
    "            ref_init_weight(m.r_r_bias)\n",
    "        if hasattr(m, 'r_bias'):\n",
    "            ref_init_bias(m.r_bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "from mem_transformer import MemTransformerLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "# Reference model built with the same vocabulary size (1000), layer count, head count,\n",
    "# d_model and inner feed-forward width as `transformer_xl` above.\n",
    "# NOTE(review): d_head=16 here, whereas `transformer_xl` was created with d_head_inner=17.\n",
    "# Assuming both arguments mean the per-head width, the two models are not the same size;\n",
    "# confirm whether this is intentional before comparing their losses.\n",
    "ref_transformer_xl = MemTransformerLM(1000, n_layer=4, n_head=3,\n",
    "                                      d_model=32, d_head=16, d_inner=71,\n",
    "                                      dropout=0.1, dropatt=0.0,\n",
    "                                      ext_len=0, tgt_len=BPTT, mem_len=0,\n",
    "                                      d_embed=None,\n",
    "                                     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "ref_transformer_xl.apply(ref_weights_init);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[6.9353, 6.7244, 7.0216, 7.1274, 6.8613, 6.8446, 6.7964, 6.8400, 6.9778,\n",
       "          7.0077, 7.0511, 7.0855, 6.9384, 6.9447, 7.1003, 6.8852],\n",
       "         [6.7513, 6.9497, 6.9046, 6.9075, 7.0219, 6.7108, 6.6179, 6.9633, 6.7379,\n",
       "          7.1386, 6.9294, 6.9583, 6.9007, 6.9911, 6.8230, 7.0799],\n",
       "         [6.8527, 6.8001, 6.8433, 6.9002, 6.7208, 6.8161, 6.8720, 6.7424, 6.9789,\n",
       "          6.7125, 7.1070, 7.2470, 6.7621, 6.9529, 6.9121, 6.6699],\n",
       "         [6.9093, 7.1417, 6.9384, 6.5931, 6.9981, 7.0698, 6.8795, 6.9889, 6.8435,\n",
       "          6.7313, 6.9985, 6.9949, 6.9754, 6.9342, 6.9275, 6.6794],\n",
       "         [6.8482, 7.0016, 6.8612, 6.8453, 6.7217, 6.7929, 6.8630, 7.0411, 6.7353,\n",
       "          6.8853, 6.8997, 7.0121, 7.0002, 6.8065, 6.8801, 6.9877],\n",
       "         [6.7444, 6.9660, 6.9177, 7.0074, 6.9557, 6.9720, 6.8851, 6.9806, 6.9515,\n",
       "          6.9527, 6.9473, 6.7841, 6.8383, 6.7471, 6.8321, 6.7420],\n",
       "         [6.8317, 6.7776, 7.0481, 6.8136, 6.9150, 6.4084, 6.8216, 6.8444, 6.7454,\n",
       "          6.8988, 6.9479, 6.7157, 7.1792, 6.9113, 7.0349, 6.7164],\n",
       "         [7.0618, 6.9694, 6.7682, 6.8886, 6.9027, 6.7948, 6.6863, 6.8170, 6.6824,\n",
       "          6.8915, 7.1446, 6.8980, 7.0670, 6.8849, 6.8771, 6.9164],\n",
       "         [7.0025, 6.9980, 6.9389, 6.7054, 6.8321, 6.9733, 6.8588, 6.7756, 6.9654,\n",
       "          6.6358, 6.7768, 6.7941, 6.9083, 7.0183, 6.9262, 6.8640],\n",
       "         [6.7498, 7.2803, 6.7334, 6.7635, 6.6662, 7.0329, 7.0988, 6.8435, 6.9043,\n",
       "          6.7753, 6.8967, 6.8992, 6.9940, 7.0396, 6.8362, 6.9354]],\n",
       "        grad_fn=<ViewBackward>)]"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "ref_transformer_xl(b1, b2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[6.8000, 6.6365, 7.0200, 7.0980, 6.6861, 6.8559, 6.8261, 6.8237, 6.9635,\n",
       "         7.0795, 7.0688, 7.0735, 6.9621, 6.9774, 7.0613, 6.8890],\n",
       "        [6.6851, 7.0286, 6.8026, 6.9284, 7.0279, 6.6342, 6.6700, 6.9929, 6.7816,\n",
       "         7.1479, 6.8753, 6.8635, 6.9053, 6.9826, 6.8304, 7.1957],\n",
       "        [6.8265, 6.8260, 6.9705, 7.0031, 6.8390, 6.7675, 6.8790, 6.8100, 6.8963,\n",
       "         6.7315, 7.0904, 7.1698, 6.8904, 7.0204, 6.9331, 6.5868],\n",
       "        [6.9183, 7.0694, 6.9417, 6.6238, 6.9771, 7.1391, 6.8853, 7.0403, 6.8509,\n",
       "         6.7267, 7.0064, 7.0531, 7.0081, 6.9838, 6.9317, 6.6311],\n",
       "        [6.8915, 6.9477, 6.8655, 6.8129, 6.8731, 6.7885, 6.7684, 6.9907, 6.7042,\n",
       "         6.8861, 6.8179, 6.9643, 7.0352, 6.8217, 6.8161, 7.0581],\n",
       "        [6.8980, 6.9834, 7.0165, 6.9784, 6.8526, 6.9986, 6.8698, 7.0037, 6.9837,\n",
       "         7.0364, 6.9384, 6.8588, 6.8577, 6.7702, 6.8438, 6.8123],\n",
       "        [6.9474, 6.8560, 6.9574, 6.8566, 6.9254, 6.4388, 6.7687, 6.8959, 6.7909,\n",
       "         7.0023, 6.9612, 6.7797, 7.0518, 6.7482, 7.0498, 6.7730],\n",
       "        [7.0244, 6.9452, 6.7149, 6.9359, 6.8671, 6.8249, 6.7241, 6.8712, 6.6352,\n",
       "         6.9809, 7.2391, 6.9532, 7.0553, 6.8603, 6.8672, 6.9049],\n",
       "        [6.9520, 7.0617, 6.8577, 6.7798, 6.8352, 6.9187, 6.8557, 6.8760, 6.9344,\n",
       "         6.5798, 6.7829, 6.6915, 6.9817, 7.1421, 6.9493, 6.9642],\n",
       "        [6.7361, 7.1144, 6.7754, 6.8073, 6.7350, 6.9979, 7.1436, 6.9174, 6.9389,\n",
       "         6.8276, 6.8532, 6.9089, 7.0572, 7.0861, 6.7977, 6.8606]],\n",
       "       grad_fn=<ViewBackward>)"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "ref_transformer_xl(b1, b2)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(6.8989, grad_fn=<MeanBackward1>)"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "ref_transformer_xl(b1, b2)[0].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Actual Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: Use some better library\n",
    "class Config(dict):\n",
    "    \"\"\"\n",
    "    Dictionary of hyperparameters whose entries can also be read as attributes.\n",
    "\n",
    "    Every entry is stored twice, as a dict item and as an instance attribute.\n",
    "    Writes through `cfg[key] = val`, `cfg.key = val`, `set` and `update` keep both\n",
    "    copies in sync. Deleting an entry is not mirrored.\n",
    "    Beware: a key named like a dict method (e.g. 'items') shadows that method\n",
    "    on the instance.\n",
    "    \"\"\"\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__()\n",
    "        for k, v in kwargs.items():\n",
    "            self[k] = v\n",
    "\n",
    "    def __setitem__(self, key, val):\n",
    "        super().__setitem__(key, val)\n",
    "        # attribute names must be strings; any other key lives in the dict only\n",
    "        if isinstance(key, str):\n",
    "            object.__setattr__(self, key, val)\n",
    "\n",
    "    def __setattr__(self, key, val):\n",
    "        # route attribute writes through __setitem__ so the dict view never goes stale\n",
    "        self[key] = val\n",
    "\n",
    "    def set(self, key, val):\n",
    "        \"\"\"Store `val` under `key`, both as a dict item and as an attribute.\"\"\"\n",
    "        self[key] = val\n",
    "\n",
    "    def update(self, dct=None, **kwargs):\n",
    "        \"\"\"Like `dict.update` (mapping or iterable of pairs, plus keywords), keeping attributes in sync.\"\"\"\n",
    "        entries = dict(dct) if dct is not None else {}\n",
    "        entries.update(kwargs)\n",
    "        for k, v in entries.items():\n",
    "            self[k] = v\n",
    "\n",
    "# Mostly prime sizes are used on purpose: they help ensure the implementation is\n",
    "# actually correct (presumably because mixed-up dimensions then fail on shape).\n",
    "config = Config(**{\n",
    "    'use_fp16': False,\n",
    "    'seed': 101,\n",
    "    'debug': False,\n",
    "    'is_ref': False,\n",
    "    'warmup_step': 0,\n",
    "    # Check default params\n",
    "    'min_lr': 0.,\n",
    "    'dropouta': 0.,\n",
    "    'clip': 0.25,\n",
    "    'log_interval': 200,\n",
    "    'eval_interval': 50,\n",
    "})\n",
    "\n",
    "# Small model for local testing, full-size model otherwise.\n",
    "config.update({\n",
    "    'debug': True,\n",
    "    'lr': 0.00025,\n",
    "    'bs': 8,\n",
    "    'epochs': 2,\n",
    "    'max_step': 10000, # shorten for testing\n",
    "    'n_layers': 4,\n",
    "    'n_heads': 3,\n",
    "    'd_model': 32,\n",
    "    'd_head_inner': 17,\n",
    "    'd_ff_inner': 71,\n",
    "    'dropout': 0.1,\n",
    "    'train_bptt': 33,\n",
    "    'eval_bptt': 41,\n",
    "    'mem_len': 41,\n",
    "    'eval_mem_len': 63,\n",
    "} if TESTING else {\n",
    "    'lr': 0.0025,\n",
    "    'bs': 22,\n",
    "    'epochs': 2,\n",
    "    'max_step': 400000,\n",
    "    'n_layers': 12,\n",
    "    'n_heads': 8,\n",
    "    'd_model': 512,\n",
    "    'd_head_inner': 64,\n",
    "    'd_ff_inner': 2048,\n",
    "    'dropout': 0.1,\n",
    "    'train_bptt': 512,\n",
    "    'eval_bptt': 128,\n",
    "    'mem_len': 512,\n",
    "    'eval_mem_len': 2100,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x11da2a7f0>"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(config.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TODO: Should we implement the vocabulary ourselves instead of relying on the imported `Vocab` class?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = Vocab(lower_case=True, special=[\"<eos>\"])\n",
    "\n",
    "# Count tokens over all three splits before the vocabulary is built.\n",
    "# (A `for` statement displays nothing, so no trailing `None` is needed.)\n",
    "for split_name in (\"train\", \"valid\", \"test\"):\n",
    "    vocab.count_file(DATA_DIR / f\"{split_name}.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "building vocab with min_freq=0, max_size=None\n",
      "final vocab size 10000 from 9999 unique tokens\n"
     ]
    }
   ],
   "source": [
    "vocab.build_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the three splits with the vocabulary built above (ordered=True, add_eos=True).\n",
    "train_dataset, valid_dataset, test_dataset = (\n",
    "    vocab.encode_file(DATA_DIR / f\"{split_name}.txt\", ordered=True, add_eos=True)\n",
    "    for split_name in (\"train\", \"valid\", \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6503, 6151, 7924, 8539, 2353, 8540, 6918, 8541, 8542, 7394, 7925, 7926,\n",
       "        6152, 8543, 6504, 6919, 6920, 8544, 5560, 6153, 8545, 8546, 8547, 7927,\n",
       "           0, 9231,    2,    3,   73,  399,   34, 2136,    1,  146,   19,    6,\n",
       "        9232,  282,  450,    3,    0,   23,    2,   13,  142,    4,    2, 5090,\n",
       "           1, 2952])"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[:50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare iterators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training batches are train_bptt rows long; validation and test share eval_bptt.\n",
    "train_iter, valid_iter, test_iter = (\n",
    "    LMDataLoader(split_data, config.bs, split_bptt, device=device)\n",
    "    for split_data, split_bptt in (\n",
    "        (train_dataset, config.train_bptt),\n",
    "        (valid_dataset, config.eval_bptt),\n",
    "        (test_dataset, config.eval_bptt),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[6503,    6,    1,    0,    5, 5485,  542,    2],\n",
       "         [6151,    2,  327,  909,    6,    7,   16,    9],\n",
       "         [7924,  225, 6798,    2, 2069,  410,  880,  157],\n",
       "         [8539,   72,  329,  380,    1,   15,    3,  299],\n",
       "         [2353,    0, 1538, 1635,  104,  495,    3,  232],\n",
       "         [8540,   29, 9744,   32,   90, 5086,    0,    6],\n",
       "         [6918,   84,   60,    2, 2392,    0,   14,    2],\n",
       "         [8541,   27, 1437,  506,    1,  112,   24,    2],\n",
       "         [8542, 2298,  105,  557,  320,  553,  195, 3524],\n",
       "         [7394,    0, 2262,   16, 1715,  497,  150,    0],\n",
       "         [7925,  495,    2,    1,    4,   68, 2440,   64],\n",
       "         [7926,    2,  885, 2978, 1046,  112,   11,   47],\n",
       "         [6152,   45,    2,    0,    8,  918,    1, 3464],\n",
       "         [8543, 1973, 5616,   39,  723,   25,  490,   14],\n",
       "         [6504,    2,    4,   13,    1, 3553,   75,  568],\n",
       "         [6919,    6,  630,  784,  284,    5, 2369,    1],\n",
       "         [6920, 2664,    0, 6171,  819,   25,    1,  569],\n",
       "         [8544,   10,   29,   20,    0,    1,    2,    4],\n",
       "         [5560,    2,   39, 2841, 8138,    2,    8,    1],\n",
       "         [6153,   21,    2,  310, 2066,   36,    2,  372],\n",
       "         [8545, 4648,    1,    3,  573,    1,    4,  160],\n",
       "         [8546,    2,   53,   18,   33, 5177, 2947,   15],\n",
       "         [8547,    0,  779,  351, 2753,   54, 1118,    2],\n",
       "         [7927,  215, 1099,  682,   18,    0,    9,    2],\n",
       "         [   0, 1193, 7123,    4,   12,   64,    2,    2],\n",
       "         [9231, 1135,    1,  156,    3,   47,    8, 1683],\n",
       "         [   2,    7, 5232,  996,   49, 5177,    4,    4],\n",
       "         [   3, 1573,    4, 3098,  151,    0,   51,  157],\n",
       "         [  73,    2,    2,    2,    3,  787,    2,  299],\n",
       "         [ 399, 4448,    0,    5,  122,  674, 4057,  232],\n",
       "         [  34,    0,   23,   32, 1226,    2,   21,    0],\n",
       "         [2136,    2, 1575,    2,  361,   72,    6,   14],\n",
       "         [   1,  586,  329, 4467,  573,    4, 2561,    9]]),\n",
       " tensor([[6151,    2,  327,  909,    6,    7,   16,    9],\n",
       "         [7924,  225, 6798,    2, 2069,  410,  880,  157],\n",
       "         [8539,   72,  329,  380,    1,   15,    3,  299],\n",
       "         [2353,    0, 1538, 1635,  104,  495,    3,  232],\n",
       "         [8540,   29, 9744,   32,   90, 5086,    0,    6],\n",
       "         [6918,   84,   60,    2, 2392,    0,   14,    2],\n",
       "         [8541,   27, 1437,  506,    1,  112,   24,    2],\n",
       "         [8542, 2298,  105,  557,  320,  553,  195, 3524],\n",
       "         [7394,    0, 2262,   16, 1715,  497,  150,    0],\n",
       "         [7925,  495,    2,    1,    4,   68, 2440,   64],\n",
       "         [7926,    2,  885, 2978, 1046,  112,   11,   47],\n",
       "         [6152,   45,    2,    0,    8,  918,    1, 3464],\n",
       "         [8543, 1973, 5616,   39,  723,   25,  490,   14],\n",
       "         [6504,    2,    4,   13,    1, 3553,   75,  568],\n",
       "         [6919,    6,  630,  784,  284,    5, 2369,    1],\n",
       "         [6920, 2664,    0, 6171,  819,   25,    1,  569],\n",
       "         [8544,   10,   29,   20,    0,    1,    2,    4],\n",
       "         [5560,    2,   39, 2841, 8138,    2,    8,    1],\n",
       "         [6153,   21,    2,  310, 2066,   36,    2,  372],\n",
       "         [8545, 4648,    1,    3,  573,    1,    4,  160],\n",
       "         [8546,    2,   53,   18,   33, 5177, 2947,   15],\n",
       "         [8547,    0,  779,  351, 2753,   54, 1118,    2],\n",
       "         [7927,  215, 1099,  682,   18,    0,    9,    2],\n",
       "         [   0, 1193, 7123,    4,   12,   64,    2,    2],\n",
       "         [9231, 1135,    1,  156,    3,   47,    8, 1683],\n",
       "         [   2,    7, 5232,  996,   49, 5177,    4,    4],\n",
       "         [   3, 1573,    4, 3098,  151,    0,   51,  157],\n",
       "         [  73,    2,    2,    2,    3,  787,    2,  299],\n",
       "         [ 399, 4448,    0,    5,  122,  674, 4057,  232],\n",
       "         [  34,    0,   23,   32, 1226,    2,   21,    0],\n",
       "         [2136,    2, 1575,    2,  361,   72,    6,   14],\n",
       "         [   1,  586,  329, 4467,  573,    4, 2561,    9],\n",
       "         [ 146,  158,   63,   17,   33,    1,    4,    1]]),\n",
       " 33)"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect one batch. Each item is a (data, target, seq_len) triple; in the output\n",
    "# below, target is data shifted forward by one row.\n",
    "next(iter(train_iter))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TODO: With FP16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logging(x):\n",
    "    \"\"\"Temporary stand-in logger: write ``x`` to stdout.\n",
    "\n",
    "    NOTE: this name shadows the standard-library ``logging`` module if that\n",
    "    module is imported in the same namespace.\n",
    "    \"\"\"\n",
    "    print(x)  # temporary!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import time\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "# TODO: Rewrite to use ignite or some cleaner framework\n",
    "loss_change = []\n",
    "val_loss_change = []\n",
    "\n",
    "def train_epoch(\n",
    "    epoch: int,\n",
    "    model: nn.Module, train_loader: data.DataLoader,\n",
    "    val_loader: data.DataLoader,\n",
    "    optimizer: optim.Optimizer,\n",
    "    scheduler,\n",
    "    train_step_start=0,\n",
    "):\n",
    "    \"\"\"Train ``model`` for one pass over ``train_loader``.\n",
    "\n",
    "    Every batch loss is appended to the module-level ``loss_change`` list, and\n",
    "    every ``config.eval_interval`` steps the validation loss returned by\n",
    "    ``evaluate`` is appended to ``val_loss_change``.\n",
    "\n",
    "    Args:\n",
    "        epoch: Epoch index; only used in the progress-bar text.\n",
    "        model: Model to optimise. Called as ``model(data, target, *mems)`` when\n",
    "            ``config.is_ref`` is set, otherwise as\n",
    "            ``model(data, target, memory=mems)``.\n",
    "        train_loader: Iterable of ``(data, target, seq_len)`` batches; must\n",
    "            support ``len()``.\n",
    "        val_loader: Loader handed to ``evaluate``.\n",
    "        optimizer: Optimizer to step. When ``config.use_fp16`` is set it must\n",
    "            also provide ``backward`` and ``clip_master_grads``.\n",
    "        scheduler: LR scheduler, stepped with the global step once the linear\n",
    "            warmup (``config.warmup_step`` steps) is over.\n",
    "        train_step_start: Global step count at the start of this epoch.\n",
    "\n",
    "    Returns:\n",
    "        The global step count after the last processed batch. The epoch stops\n",
    "        early once ``config.max_step`` is reached.\n",
    "    \"\"\"\n",
    "    # Turn on training mode which enables dropout.\n",
    "    model.train()\n",
    "    mems = tuple() if config.is_ref else None\n",
    "    # The step is formatted with '{:>8d}' below, so it has to be an int.\n",
    "    train_step = int(train_step_start)\n",
    "    train_loss = 0\n",
    "    best_val_loss = float(\"inf\")\n",
    "\n",
    "    pbar = tqdm(train_loader, total=min(config.max_step - train_step, len(train_loader)))\n",
    "    # `inputs` rather than `data`, so the loop does not shadow the `data` module.\n",
    "    for inputs, target, _seq_len in pbar:\n",
    "        model.zero_grad()\n",
    "        if config.is_ref:\n",
    "            ret = model(inputs, target, *mems)\n",
    "            loss, mems = ret[0], ret[1:]\n",
    "            loss = loss.mean()\n",
    "        else:\n",
    "            out_dict = model(inputs, target, memory=mems)\n",
    "            loss, mems = out_dict[\"loss\"], out_dict[\"memory\"]\n",
    "\n",
    "        if config.use_fp16:\n",
    "            optimizer.backward(loss)\n",
    "        else:\n",
    "            loss.backward()\n",
    "        # Read the scalar once; each .item() call copies the value to the host.\n",
    "        loss_value = loss.item()\n",
    "        train_loss += loss_value\n",
    "        loss_change.append(loss_value)\n",
    "\n",
    "        if config.use_fp16:\n",
    "            optimizer.clip_master_grads(config.clip)\n",
    "        else:\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)\n",
    "\n",
    "        optimizer.step()\n",
    "        assert_nonan_params(model)  # check for nan\n",
    "\n",
    "        # step-wise learning rate annealing\n",
    "        train_step += 1\n",
    "        if train_step < config.warmup_step:\n",
    "            # linear warmup stage\n",
    "            curr_lr = config.lr * train_step / config.warmup_step\n",
    "            optimizer.param_groups[0]['lr'] = curr_lr\n",
    "        else:\n",
    "            scheduler.step(train_step)\n",
    "\n",
    "        if train_step % config.log_interval == 0:\n",
    "            # Mean loss since the previous log line.\n",
    "            cur_loss = train_loss / config.log_interval\n",
    "            log_str = (\n",
    "                '| epoch {:3d} step {:>8d} | lr {:.3g} '\n",
    "                '| loss {:5.2f}'\n",
    "            ).format(epoch, train_step, optimizer.param_groups[0]['lr'], cur_loss)\n",
    "            log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))\n",
    "            pbar.set_description(log_str)\n",
    "            train_loss = 0\n",
    "\n",
    "        if train_step % config.eval_interval == 0:\n",
    "            val_loss = evaluate(model, val_loader)\n",
    "            val_loss_change.append(val_loss)\n",
    "            # TODO: log the validation loss / perplexity, and checkpoint the model\n",
    "            # and optimizer state whenever the validation loss improves.\n",
    "            if val_loss < best_val_loss:\n",
    "                best_val_loss = val_loss\n",
    "\n",
    "        if train_step >= config.max_step:\n",
    "            break\n",
    "    return train_step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model: nn.Module, val_loader: data.DataLoader):\n",
    "    \"\"\"Return the mean loss of ``model`` over ``val_loader``.\n",
    "\n",
    "    Each batch loss is weighted by the batch's ``seq_len``. The model is put\n",
    "    into evaluation mode and its lengths are reset for evaluation; afterwards\n",
    "    the training lengths and the model's previous train/eval mode are\n",
    "    restored, even if evaluation raises.\n",
    "\n",
    "    Args:\n",
    "        model: Model to evaluate; must provide ``reset_length``.\n",
    "        val_loader: Iterable of ``(data, target, seq_len)`` batches.\n",
    "\n",
    "    Returns:\n",
    "        ``sum(seq_len * loss) / sum(seq_len)`` over all batches.\n",
    "\n",
    "    Raises:\n",
    "        ZeroDivisionError: If ``val_loader`` yields no batches.\n",
    "    \"\"\"\n",
    "    was_training = model.training\n",
    "    # Turn on evaluation mode which disables dropout.\n",
    "    model.eval()\n",
    "\n",
    "    # Evaluate with the second argument fixed at 0 and the last one enlarged by\n",
    "    # (train_bptt - eval_bptt).\n",
    "    # NOTE(review): presumably the arguments are (tgt_len, ext_len, mem_len) as in\n",
    "    # the reference implementation -- confirm against reset_length.\n",
    "    model.reset_length(config.eval_bptt,\n",
    "        0, config.eval_mem_len+config.train_bptt-config.eval_bptt)\n",
    "\n",
    "    # Evaluation\n",
    "    total_len, total_loss = 0, 0.\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            mems = tuple() if config.is_ref else None\n",
    "            for inputs, target, seq_len in val_loader:\n",
    "                if config.is_ref:\n",
    "                    ret = model(inputs, target, *mems)\n",
    "                    loss, mems = ret[0], ret[1:]\n",
    "                    loss = loss.mean()\n",
    "                else:\n",
    "                    out_dict = model(inputs, target, memory=mems)\n",
    "                    loss, mems = out_dict[\"loss\"], out_dict[\"memory\"]\n",
    "                total_loss += seq_len * loss.float().item()\n",
    "                total_len += seq_len\n",
    "    finally:\n",
    "        # Always restore the training lengths and the caller's mode so that a\n",
    "        # failed evaluation cannot leave the model in its evaluation setup.\n",
    "        model.reset_length(config.train_bptt, 0, config.mem_len)\n",
    "        model.train(was_training)\n",
    "\n",
    "    return total_loss / total_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader, valid_loader):\n",
    "    \"\"\"Train ``model`` for up to ``config.epochs`` epochs or ``config.max_step`` steps.\n",
    "\n",
    "    Builds an Adam optimizer and a cosine-annealing LR schedule, then calls\n",
    "    ``train_epoch`` once per epoch, carrying the global step count across\n",
    "    epochs.\n",
    "\n",
    "    Args:\n",
    "        model: Model to train.\n",
    "        train_loader: Training batches; its ``len()`` sizes the LR schedule.\n",
    "        valid_loader: Validation batches, evaluated periodically by ``train_epoch``.\n",
    "    \"\"\"\n",
    "    optimizer = optim.Adam(model.parameters(), lr=config.lr)\n",
    "    total_steps = min(config.max_step, len(train_loader) * config.epochs)\n",
    "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,\n",
    "                    total_steps, eta_min=config.min_lr)\n",
    "    global_step = 0\n",
    "    for epoch in range(config.epochs):\n",
    "        if global_step >= config.max_step:\n",
    "            break\n",
    "        # Pass the loaders given to this function; the previous version used the\n",
    "        # notebook globals train_iter / valid_iter and ignored its arguments.\n",
    "        global_step = train_epoch(\n",
    "            epoch,\n",
    "            model,\n",
    "            train_loader,\n",
    "            valid_loader,\n",
    "            optimizer,\n",
    "            scheduler,\n",
    "            global_step,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_final(model, val_loader):\n",
    "    \"\"\"Evaluate ``model`` on ``val_loader`` and report loss and perplexity.\n",
    "\n",
    "    Each batch loss is weighted by the batch's ``seq_len``. The model's lengths\n",
    "    are reset for evaluation and restored to the training values afterwards,\n",
    "    even if evaluation raises. The model is left in evaluation mode.\n",
    "\n",
    "    Args:\n",
    "        model: Model to evaluate; must provide ``reset_length``.\n",
    "        val_loader: Iterable of ``(data, target, seq_len)`` batches.\n",
    "\n",
    "    Returns:\n",
    "        ``{\"loss\": mean_loss, \"ppl\": exp(mean_loss)}``.\n",
    "\n",
    "    Raises:\n",
    "        ZeroDivisionError: If ``val_loader`` yields no batches.\n",
    "    \"\"\"\n",
    "    # Turn on evaluation mode which disables dropout.\n",
    "    model.eval()\n",
    "    total_len, total_loss = 0, 0.\n",
    "\n",
    "    model.reset_length(config.eval_bptt, 0, config.eval_mem_len + config.train_bptt - config.eval_bptt)\n",
    "\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            mems = tuple() if config.is_ref else None\n",
    "            for inputs, target, seq_len in val_loader:\n",
    "                if config.is_ref:\n",
    "                    ret = model(inputs, target, *mems)\n",
    "                    loss, mems = ret[0], ret[1:]\n",
    "                    loss = loss.mean()\n",
    "                else:\n",
    "                    out_dict = model(inputs, target, memory=mems)\n",
    "                    loss, mems = out_dict[\"loss\"], out_dict[\"memory\"]\n",
    "                total_loss += seq_len * loss.item()\n",
    "                total_len += seq_len\n",
    "    finally:\n",
    "        # Restore the training lengths even when evaluation fails part-way.\n",
    "        model.reset_length(config.train_bptt, 0, config.mem_len)\n",
    "\n",
    "    loss_val = total_loss / total_len\n",
    "    return {\"loss\": loss_val, \"ppl\": math.exp(loss_val)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate our implementation with the shared hyper-parameters from `config`.\n",
    "transformer_xl = TransformerXL(\n",
    "    num_embeddings=len(vocab),\n",
    "    n_layers=config.n_layers,\n",
    "    n_heads=config.n_heads,\n",
    "    d_model=config.d_model,\n",
    "    d_head_inner=config.d_head_inner,\n",
    "    d_ff_inner=config.d_ff_inner,\n",
    "    dropout=config.dropout,\n",
    "    dropouta=config.dropouta,\n",
    "    seq_len=config.train_bptt,\n",
    "    mem_len=config.mem_len,\n",
    ")\n",
    "# NOTE(review): the data iterators are placed on `device`, while the model goes to\n",
    "# CUDA whenever it is available -- confirm the two always agree.\n",
    "if torch.cuda.is_available():\n",
    "    transformer_xl.cuda()\n",
    "# Trailing semicolon suppresses the module repr that apply() returns.\n",
    "transformer_xl.apply(weights_init);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the same NaN check that train_epoch applies after every optimizer step,\n",
    "# here on the freshly initialised model.\n",
    "assert_nonan_params(transformer_xl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prod(x):\n",
    "    \"\"\"Return the product of all values in ``x`` (1 when ``x`` is empty).\"\"\"\n",
    "    result = 1\n",
    "    for factor in x:\n",
    "        result *= factor\n",
    "    return result\n",
    "\n",
    "def num_params(model: nn.Module):\n",
    "    \"\"\"Return the total number of scalar parameters in ``model``.\n",
    "\n",
    "    Uses ``Tensor.numel()`` on every parameter instead of multiplying the\n",
    "    shape entries by hand; a model without parameters yields 0.\n",
    "    \"\"\"\n",
    "    return sum(p.numel() for p in model.parameters())\n",
    "\n",
    "def num_params_per_param(model):\n",
    "    \"\"\"Map each named parameter of ``model`` to its number of elements.\n",
    "\n",
    "    Keys are the names produced by ``model.named_parameters()``; values come\n",
    "    from ``Tensor.numel()``.\n",
    "    \"\"\"\n",
    "    return {name: p.numel() for name, p in model.named_parameters()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "381842"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of scalar parameters in our model.\n",
    "num_params(transformer_xl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| epoch   0 step     3400 | lr 0.000132 | loss  6.15 | ppl   469.951: 100%|██████████| 3522/3522 [17:58<00:00,  3.27it/s]  \n",
      "| epoch   1 step     7000 | lr 2.41e-08 | loss  6.07 | ppl   432.002: 100%|██████████| 3522/3522 [17:23<00:00,  3.37it/s]  \n"
     ]
    }
   ],
   "source": [
    "# Train our implementation. Per-batch losses accumulate in `loss_change` and\n",
    "# periodic validation losses in `val_loss_change`.\n",
    "train(\n",
    "    transformer_xl,\n",
    "    train_iter,\n",
    "    valid_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 5.9934049364361615, 'ppl': 400.7769092407092}"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss and perplexity of the trained model on the validation iterator.\n",
    "evaluate_final(transformer_xl, valid_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshot the per-batch training losses so later runs that append to\n",
    "# `loss_change` do not alter this copy.\n",
    "loss_change_self = list(loss_change)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x125553cc0>]"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd8FGX+B/DPk05C6KGX0KvUiKDSm4CnWA89y3l3ep56Fs7TcOphl1PPs57I2c6figX1LBERkC7SW+gtQKihJUAIac/vj53ZbJnZnc3u7MxsPu/Xixeb2dmdb5LNd575zlOElBJEROQccVYHQEREoWHiJiJyGCZuIiKHYeImInIYJm4iIodh4iYichgmbiIih2HiJiJyGCZuIiKHSTDjTRs1aiQzMzPNeGsiopi0evXqY1LKDCP7mpK4MzMzsWrVKjPemogoJgkh9hrdl6USIiKHYeImInIYJm4iIodh4iYichgmbiIih2HiJiJyGCZuIiKHsVXiXrS9AHuPn7U6DCIiWzNlAE513fLuCgBA3tTxFkdCRGRftmpxqz5duc/qEIiIbMuWifvhLzZaHQIRkW3ZKnFvfnKM+7GU0sJIiIjsy1aJOzWpquS+bPdxCyMhIrIvWyVuAHhgZCcAwHtL86wNhIjIpmyXuMf3bAYAmLP5iMWREBHZk+0Sd8O0JKtDICKyNUOJWwhxnxAiVwixSQhxv5kBpSbHm/n2RESOFzRxCyF6ALgdQH8AvQBcLoToYFZAyQmuxN26QapZhyAicjQjLe6uAJZLKYullOUAFgK42tSgBLDvRDHKKirNPAwRkSMZSdy5AAYJIRoKIVIBjAPQyncnIcQdQohVQohVBQUFYQVVqXThPnm2NKz3ISKKRUETt5RyC4B/APgRwA8A1gGo0NhvupQyS0qZlZFhaKFiIiKqBkM3J6WU70gp+0kpBwM4CWC7uWEpx43GQYiIHMZor5LGyv+t4apvf2xmUP+45gIAQEmZX8OeiKjGMzqt6xdCiIYAygDcLaU8ZWJMSE9JBACcY+ImIvJjKHFLKQeZHYinWkmuLoHFpUzcRES+bDdyEgD2nygGALw6b4fFkRAR2Y8tE7da2849UGhxJERE9mPLxD2hdwsAwJ1D2lscCRGR/dgycavD3oUQFkdCRGQ/9kzcia6wnvpus8WREBHZjy0Td1K8LcMiIrIFW2bIuDiWSIiI9NgycRMRkT4mbiIih2HiJiJyGNsm7juHtOdNSiIiDbbNjDuPnkZpRaV7+DsREbnYNnHP3XIUAPDN+oMWR0JEZC+2TdyqI0UlVodARGQrtk/c6SlGpwwnIqoZbJu4n7+mJwBgnlIyISIiF9sm7oz0ZADA1sOnLY6EiMhebJu4uzarY3UIRES2ZNvE3bRuCgCgXaM0iyMhIrIXW9/56968DpopCZyIiFxs2+IGgMT4OJRWSKvDICKyFZsnboGy8kqrwyAishWbJ+44lFUwcRMRebJ1jXtl3gmUsVRCROTF1i1uNWmfPFtqcSRERPZh68StmvzlRqtDICKyDUck7h82HbY6BCIi23BE4iYioipM3EREDmPrxD3xwlZWh0BEZDu2TtzPXX2B1SEQEdmOrRO3EML9WEr25yYiAmyeuD1tyC+0OgQiIltwTOIur+TQdyIiwAGJu0vTdABA7oEiiyMhIrIH2yfu+0d2AgBM+WaTxZEQEdmD7RP3ubJyq0MgIrIV2yduIiLyZihxCyEeEEJsEkLkCiFmCCGitp4Yp+MmIvIWNHELIVoAuBdAlpSyB4B4ABPNDkxV4dGbZOfRM9E6LBGRbRktlSQAqCWESACQCuCgeSF569emvvtxSVlFtA5LRGRbQRO3lPIAgBcB7ANwCEChlPJHswNTdWic7n588NS5aB2WiMi2jJRK6gO4EkBbAM0BpAkhbtLY7w4hxCohxKqCgoLIRwrgm/VRa+gTEdmW
kVLJSAB7pJQFUsoyAF8CuNh3JynldClllpQyKyMjI9JxAvCeu4SIqKYykrj3ARgghEgVrsw5AsAWc8PS9u36gygsLrPi0EREtmGkxr0cwEwAawBsVF4z3eS4dPV6MmrldSIiWzLUq0RKOUVK2UVK2UNKebOU8rzZgXn66S9Donk4IiJbc8TIydrJCVaHQERkG45I3AnxjgiTiCgqHJERE+PZm4SISOWQxO2IMImIosIRGZGJm4ioiiMyYnycd6mkqIR9uYmo5nJE4vb15LebrQ6BiMgyjkzca/aetDoEIiLLOCZxJ3iUS3YfO2thJERE1nJM4v7f3ZdYHQIRkS04JnH3aFEXl3VvanUYRESWc0ziBoDfXdrW6hCIiCznqMR9qrjU/biMqwgTUQ3lqMTt2Z/7xR+3WRgJEZF1HJW4Pe08whXfiahmclTi5splREQOS9wN0pLdj7ccKsL58goLoyEisoajEnfvVvXcjw8WlmDw8/MtjIaIyBqOSty+jhRFdQU1IiJbcHTiJiKqiRyXuC/t0MjqEIiILOW4xH1dVkurQyAispTjEvcVvZpbHQIRkaUcl7iFT2fuL1bnWxQJEZE1HJe4ff3l8/VWh0BEFFWOT9xERDUNEzcRkcMwcRMROQwTNxGRw8RE4v54+T6rQyAiihpHJu7/+31/r6//9tVGiyIhIoo+Rybuvq3rWx0CEZFlHJm405IT/LZlZudYEAkRUfQ5MnETEdVkMZW4czYcsjoEIiLTxVTivveTtVaHQERkuphK3BWVElJKq8MgIjJVTCVuAHh13k6rQyAiMpVjE/e2py/T3P7pSg7GIaLYFjRxCyE6CyHWefwrEkLcH43gAklOiNfcXhlmpWTR9gKcKi4N701CVFkpURlu4ERUYwRN3FLKbVLK3lLK3gD6ASgG8JXpkRlwTV//Zcwkqp8AS8srccu7KzBx+i/hhBWy9o98j3GvLo7qMYnIuUItlYwAsEtKudeMYEL12OVd/baFc2+yUnnx1sOnceDUueq/UYikdB2TiMiIUBP3RAAzzAikOnyXMQOAo6fPR6RnyaX/+MnwvjdM/wXfrD8Y9jGJiIwwnLiFEEkArgDwuc7zdwghVgkhVhUUFEQqvoDi4/wTN+BqvZ4vr9B93VPfbcbzP2wN+N6h5P5lu4/j3hlrsTG/EF+vO+Denpmdg8e/2WT8jYiIDAilxT0WwBop5RGtJ6WU06WUWVLKrIyMjMhEF4RO3sbYVxZj8hcbcfZ8OXIPFOL4mfNeN//eWbIH/16wCwCweu9JFBaXAQivzAIAd364Gvd9ss5r2/s/54X3pgHsPHoaV76+BKdLykw7BhHZj/9sTfpugI3KJAAQp1EqUX259gDmbDmC0yXl7m15U8d77VNZKXHNmz+jV6t6+PruS8KOJ5p1cQB4cfZ2rM8vxJIdxzD2gmZRPTYRWcdQi1sIkQZgFIAvzQ0nNIESNwCvpK1FbWBvzD+lfO2sLnlG4v0h9xAys3OiflIhIvMYStxSyrNSyoZSykKzAwqFXqlEz6aD2uHrpb8+T/6IkjL/WvnX6w5gzb6ToR3cIjNX5wMANh8ssjgSIooUx46cBPRvTuoZ/+oSr6+3HPJOZr417pPFZcg/6d9Sve+Tdbj63z8HPJaaMPUcLizBtIW7sDLvRMD9wlFRKTF3y1EAcNQcLmfPlzsqXqJoc3Ti1uoOGMx/PW4WXv6aK5EHyhH3fLwGmdk5mi1szx4kvh78fL3m/mfOl2PpzmMY8Nw8TJ21FddNW2Y8eACF58rw/tI9hhLbB8vy3I+dkgYLTp9H9ymz3TePicifoxM3AKx5bFRI+0/R6Z53tKhEM7mpA2NmbzoMAMg7dtb93H2frDPc3W9jfiHu+2Qd/vblRmw8UP2K06RP1+HxbzdjZZ7/iaSiUuK7DQfdSf2Jbze7n/PM89Fuze45dhYHDdbYjxSVAAhtbvWVeSdQWl5ZrdiInMjxibtBWhL6ZzYI+32em7UV
J87oz1FSUHQeAPDFGu8SiFZ3v/eW7vHbdua860apmpiqa95WV+nDs5+6euHx3tI9uOfjtfhqrf+VQP7JYkgp8eaCXWg7+XvkhnHyCNWwFxfg4qnGBzT5+nrdAZw8q/272Xq4CNdNW4Znv98S8D1Kyiow9IX5WLbreLXjILILxyduoGqoeji+WnsAg1+YH5FjebZ0AaC8ohI7C84AcCXZhduMD1CSUmLZruM4VVyKj5cHnvnwcKHrpHDszHnN5z9duR//UAYemVFb35B/Cv+asz2i77n/RDHu+2Qd7pmxRvP5E0pC33o48M3XHUfOIO94MZ75fnPA/YicIJR+3LZVEaVL/89W7cf5stAvyTs8Msvr62W7jbX63pi/E9+uPxh0HpMSJaZgJX8z50PZfuQ0rnh9KQDggVGdIva+pRWu7+3QqcBXKkY/ArznSbEgJhJ3NGZEzdl4CF9qlCBCVRJC4n9h9raAz2855ErEU77ZhAl9WoQUR6QT2Oh/LYro+x097UrU6rlI70pHwNgN6mrcx44ZJ8+WovBcGTIbpVkdCkVITJRKBnVoZPoxzkfo5te6/ad0n7vi9SWYs/kIMrNzgs4JnnfsLPadKAbg6mlyvrwCPyn174NBWqeeth0+jVfn7TC8vxElZRU4c74c415ZjIdnbnBvbzc5B796bUmAV1Y5ptxvUAdZhXueqckt7eH/XIChLy6wOoyYt/VwkbvBYbaYSNyRvDS30ob8Qtz+wSoA+r1fVI997f38szlbsKvA1ePFyPwoagv0umk/46U521FSVoFfdh/HZyv3B31t7oFCZGbnYMmOY5rPj31lMXpMmY3Nh4rw6aqq96uUCLlHjZq4g91b8Hw290AhXp67HZsPFmHSp+tQYfCSTEpp+rwve4+fxdEwb1CH6mQx57Lx9P3GQ7qD8cJx2cuLMeT5BRF/Xy0xkbjj4wQGdTS/1R1NX68LbZrYAp0bknqkBA6eOociZVqAikqJidN/wUNfbAjySmDFHteNzblbNOcbwx6PLpOh8r2xqp5gKpULntwDhV43VrVKIFe8vgQvz92Buz5ajS/XHsC+E8WGSiXv/5yHCx7/EfuVKxkzDHlhAfo/Oy8i73XHB6vw5xlrI/JeNcldH63xGow35etcfL4qeIPFiHMaI63NEBOJGwAu71mzJ1nyHYw0a6N3P+hjZ0r9WuKeXfTKPVqlwfpch1t1WLf/FDKzc9y9YDypVxwq9ds6cOocPliWh8tfWxJ00JJvA7tSSkz+cmPQuNS++vtPFmPb4dP494KdeP2nyJaRIunHzUfwrYnzwFdWSrwweysKTofWKHCa/y7bi7/O1G+wjHppId6Yb69FyGMmcXdskm51CJY6e957Qq0/feTdfW7aQu+RiL7J13NQzsVTf8K+465W5yNfbXSfBApOn8f+E8V46jtXl7olO4/hue+3oNcTP4YU6wfKCWTJzmPIP1mMzOwcvLVwF9buO4myCu/IPCcS+/vXAcpH0jXNbWZ2jt9TR4vOa5ZodhWcQVmFxr0LCYx5eRGe/2EbXvxxO86VRqcVpWfYiwtw5etLIKWM6uCpX3YfxxvzdyHbwFVYLNtx9EzQjgLRFjOJu2/r+mhZv5bVYVhmQQh9w404otxk+Wj5PvzpozU4X16BC5+Zi0HPV/V133n0DN5atBuF50KroaqpRwDYpEx+9dysrbjKZ/6XkrKKoCUO9ekVeSfwn0XeA5+0Upx6vKNFJRjxz4V44tvgI1+tnjVyz7GzWJ9fiA9/2Yu2k7/X7acfaepVWLAb8zuOnDZ9sesfcg973ej2tGh7Ad6sYVMkxEziBoC3bu5ndQiO4dtyW+vT28U3X3Z+9IeIHXv7EVc3xr98vh57j+vXw6ct3BV06l5Pn/rUKfcqVw2+b6EO6gGA5buDD0QKlpOW7TqOLo/NCvkEFqrPlYnLtCY+i5SF2wtCmgY4Z8MhjPrXIvzpo9WmxQS4FilRf787jpzGs99vcX+Gb3l3
hXtgmVWu+vfSqB4vphJ39+Z1rQ7BsW57b6XX12b2e97kMcXsPGX2Qi3nyioM9tIOzaDn57sHQXl+n3p9wtUEsTG/ULMu/9pPO1BSVmn6NAKh/iyW7z4ecu+JT1e6Rueu26ffbdXT3R+7SnKzN7luVEsp8fbi3Zo/p0i5+Z0VmL5oN45aXHuvqJT48Je9KC2vxFqDP69IiYkBOBS6YHXbSZ+tj0pPnXDvwv96+i9B9wmU8IwM4FEb3L96fQni4wR2PTvO+/lqVAnmbzuKYZ0bh/SaPOUKwmid28jPxlco34vW/DFbDp3G0zlbsPd4MZ6a0CPk4wfT8/HZ7nl/fGM9cOocjp0+j16t6lX7/U8Vl2JXwRn0axN4/qPS8kp8tTYfj/4vN+iYCzPEVIubjPtnkDlF9h4vxoe/BJ4bJRI25Ou3CI2OigyHOlnXkaISd3lgc4B52n37hO8uOIMN+fqtrXeX7MGXa/znZve9wvG159hZvDpvh1eSVksxerk1MzsHkz5bp/NsaEor/E+oZ86XIzM7B9MXuerJWrVv9WZvoIFm4SgqKXeXrsoqKr1uLl8y9Sdc+UbgkkWwPv2/eXs5rnlTu9eSuqjKnM1H0OnRWe4Jy05Z0E+eiZtszUi/8mB+CVDHLimrxPr9p3DRs/PcI1GfzvGeadC3hfvIVxvdvVeG/3Mhzga4ennyu82Y9Jn/3OyBrNl3EsNeXICX5mzX7J8fqFX85Rrj0zJsOVSExTtcN7Wf+m4zMrNz3CWk576vqhmr5SS1Zfn+0jxXHBqnkEiU2CorJTKzc/D24t0B97tu2jJ09JkHSI0zMztHc2rgGSsCN0bUMp7WVc2dH7rq+OoI5SU7rZtpMuYSd6wNxKnpItFb5l9z9a8uKqV03yzVff2c7fj1W1WtsI+UWRp9hzd75qz8k8VBV0HS47m6UqWJ04yPfWUxbn5nBQDgnSWuHjlqq16rfqyuOKU2WrVar+qN03B64pQp37TvCdTXYZ0RqLuUmTjfXuJK/G/M34nVe13z1+tNDwy4ekmptL63qs+i67lo9e7REnOJ++1bs6wOgSLEt++5p4XbCyLSp1mI4HXd/y7bi+V7/FvtNwSoIV/75jLNVZA8TZ21Negfv/bMl97bfE8gUko8NNP/2D/v0p6iQI86jXCRWqJRDqu2qmdtPOy1/+HCEtyljB/wDDvUfvCer31r4S7NvvlGXq+eSF+YvQ3XvOm/1KB6I1Y18qWF7sd6FZXM7BzMWBGZUZbhiLnEnZwQb3UIFAW3vrsCORuNr5Kj50jReXwU5PJZz/4T+l3mtFqDvqNZpy3chUe+8h7R6dsf2kj/6M9W7vfaL/dAET5b5d/av/E/yzVPdp51Ys+nf1BGkq5X7kOcVqZHUFvlpT6DlwY8VzWUX32f+VuPouvff9BdXPtcaQWKS13vu+VQEV6Y7d2t77lZoXfzc48T8KnbZGbneN3befo7/RZ9JOb4N1NM9ioZ2jnD6xJ7TPcm7u5KFDumBBpJGYL1Eb6RdqjQP6H//etcfLBsr9923yXXjFx+Swl84nGyefHH7WhYO9n99fe5+ie0fy/YhSGdMtCqQWrV6w2OCrz1XVdZpdhAC1pNe2orf3XeSfRtXd9vv35Pz0FxaQXaZaThwMlzOF9eidsuaWsoHt1jKwcPtpb4aZ/Rxp7UxG3GZFSREHMtbgCYdpP3QJxuzdi/OxYdD1CvjAbfFueNby9HaXmlu2XqSStpA/6twnifbOM5UlX1xZp8ZPvMveI5n0igUYSzcg/h8teWeJV53loU+CYg4FrdSK+mrEVKib9+vh7/WeyqnZdXSqzeewI3vb3cq4WvngR2F5x191IJt7GrXlWE0ytJvYDxnIxKz9tL/JcqNFtMJu6UxKpyycXtG+Ke4R0sjIZqktMlZX7JN5Cfth5FZnaOu9WrN+OiJ60aq9GJoHIPuHpN+HZ5DEZd3chTsHsMn3vcnK2orMSkz9Zjyc5jeCZni7s8oiXc
MoX66hUGlud7WefGtd1LJTGZuAHgqSu7A3DNYRIfJ/DbizP99mnTMNVvG1E4+j0919BMhL5en78T6/afwsNfhP5aAPi/X7Rb9GZSF67W4lsCev/nve4pCN7/OQ9TA9Su/1aNn58nz5z7WpBFQl6eq/282XOvhCsma9wA8JuL2gAArr+wFQD4zXkxoF0DjOzaJGiXI6JQrdDogWKE2XOdRFqgYd6+g3N8a/cfLNuLFvW0J4ULdEIw4ob/VJWBgg0009P7yTl4ZFzXsOIwU8y2uOPiBG4emOnuZeI7MKBOSiL+MKidBZERafPtnuZkRrpqVqfHSDTNsPHvI2YTN5HTfO/TL9rO1AEteg6aOMlU1Ni4WlLjEnfj9OTgOxFRQFoDWmLN7jCW4DNbjUvcY3s0hRDAVX1aAADW/320xREREYUmZm9O6mlZPxU7nxnn7rJVNzVRY59apk5WT0QUjhrT4r6hf2ukJydgXM9mQfvZLnl4OLo3rxNwnyZ1wi+5JMabP20pEcWeGpO4OzSujY1PjNHtguQre2yXgM8vyx4RibCIiEJWYxK3Udf1a2lov0jMOxyNhQKIKPYwcQP44Hf93Y87N00H4KqFB+I7x4Tq+Wt74ukJPTB30uCgx20cpNxyaQfOLU5E/pi4AQzulOF+PKZ7UwBA20ZpWP434+WQ9hlpAIDrs1rhpgFt0KFxOhKC1NI/++PAgDGN7BramoREVDMwcSvuHtYegKtHiapJnRS//WbfPxgvXtfLb3vOvYOw8XHvroWX92zmt9/9Izti5p0DkTd1PJp71NvbNUrz2m/8BU1x04A2oX0TOtSTChHFBiZuxV/HdEHe1PG6JRBV56bpuFajDp6SGI/0FO+uhS9oJPj7R3ZCVqb/CtKPX9Hd6+vrs1ohIT4yv55w5zcmInth4jZRYgiJt3frel5fBzuBhKJ3q3rBdyIixzCUWYQQ9YQQM4UQW4UQW4QQ+sXZGHXbJZlIT9Yer/Tc1RdU+33vHdER7TPSUCclEV2UG6NaVj86EuM1Si9GNK3rX/IhIucy2iR8BcAPUsouAHoBqHFzoU75VXdsfGKM5nNZbfyXZDJq0qhOmPeXobrPD+/SGK0bpHotTaXljsHaMx3+9uJMNKqdjHuGcTEJolgRNHELIeoCGAzgHQCQUpZKKSO7SJ/DdWyi31IO17u/vRCLHhrmt71uLe96+u2D2qFOSgK++/OlXtsbpiUBAFKTrV9EWZ0fhojCY6TF3RZAAYD3hBBrhRBvCyH8uikIIe4QQqwSQqwqKCjwfxcKW6Cqd0Z6MjY8PgY9Wnivr6nOTDm2R/XKLEZ9ffcleHpCD9TTmPtFFax7JBEZYyRxJwDoC+BNKWUfAGcBZPvuJKWcLqXMklJmZWRk+D5NBgSqcQMImBSDadsoDXlTx1f79cHExwncNKAN2jbS7nq44MGhERltSkTGEnc+gHwp5XLl65lwJXKKMN/l1Xw9dnk3tKxfC4sfGlbtJJiWZE3JpHaK8YkofacdmPNA8FGotw+KTJfHXJ37GER2EjRxSykPA9gvhOisbBoBYLOpUdVUQZJxckI8ljw8HK0apKK/0hd89v2D8doNfQwfYtOTl1U7vLE9muo+p55IhnXWHu1ZPzXJ8NwsWZneN3vjDJRYLtToG18dtXV6DgWj1befap6L2kbmcxiM0U/pnwF8JIRIArAbwG3mhRRbcp8YAwGgrKISp4qNLQablBD8QuiViX2w98RZdG6a7p5fRYuBpf8My/BYPSgxXqCswv/N7xnWARP7t0Lj9BRkZue4twebSjcS1j42Cn2emmP6cbTcNbQ9Zq7Ot+TYZB8zbh8QleMY6g4opVyn1K97SiknSCkDLzhXQzw4uhN+f2ngS/TayQlIS05AvdQkZOrUf309M6FH0H1qJcWjS1PtOcPHdG9i6DgLHhxqaD/V5LFVq17rfUDj4gQap2v3G9datMJTZkP/ib1a1Ktl+ORTX+lBE0g4XTeJgjFy
dRiR40TlKA4W6IbgPcM74rHLu0XsWJGa5vXafq3cj2WAFU+NJDpPtZLi3XO51K2ViOFdjE2C9dL1rqH/k0Z1Cvjz6q9xmbk0e3hIMQaSPbYLnrmq+oOl9Nw3omPE35MoECbuIJY8PBxrHhsV1WOGW90Y2L5hROLQkqyUcYRw9THv2izwSkFA1fS0KYnxQa9QzGakQTTvL0NCuhpp1SDwFMBEkcbEHUTt5AQ0CLFlaie/vrCV/pPVOEOE8hLPmngwA9p5t7a9ZzQMflQjcUlprNbePqM2Mhul6XZt9BVsmTuraM1iaYbfRWESs3EX6N8Yr4mYuG0oklWyZnWNLdUWSK3EeLx324U+W11Rygjc/by6Twt8+PuLvLbl3DsI66eM1nlF9XRtlo6GafonE9+rh/kGW91GrjqsEKynS68ITT525xDt6RYiqbIy+D6vhtC7qjqMLI4SLUzcNhTBjiAB1U5J8Jq4SmsAkBDAlqcu0+3m98/re2Fk1ybopDHs32hOT0qI85vCNiUx3j2sP1I9Y4Z2boy6qYlY9Ff/KQRWPzoSs+4bFJkDGVDdboeRNChCKyw11pi3PtIqg3wIHru8G1qbXLLq0DjdUKktGpi4bSRSIwuNtoLj4wTuGupaQKJL03TcMjDTwJt7f9m9eV28fWtWSFPY+rpdmSDrziHt0aFxbYzqFtnL4hWPjPCqWTes7V36SoqPCzqJl6qxT/mnOlPm/nFIO0wMVMIyIBK9Y265ODILdQDmt0ZrBRk49puLWpt6fCM+vv2i4DtFCBO3jQxo57qpqNV6tYreOcDIScbIPi3r10L7jNoAgHYZtTF30pCg9xSGdvafUkGNU6vV1Tg9xasrZppPazdQz5s9z43DpFGdAAATejfHd/deipsGtHZfDUz5Vei9im41cIJ848bAg5ONdi1NDZDwAvVi8j1BBZOh0wU0UlKTEvBz9vCInCB6tqxr+P6Fr0Dz5HdoXLu6IYWMidtGru3XEiseGWH6wgcdG9dGI58Wpt4Hsp3PsmeRLuMYuTjw3eWBkZ389umjLETxzFVVfeBrJcajqYHL+EAxCCHQQlliTghXH/WnJ1yAB0a6ugAGWlT64cu6aG73XLJOT60k/T/Nl3/dO2BCTvK4+um4ufQEAAAOsElEQVTRvK7ufuke0xAM6ZSBv42rindi/9BasMkGBo2Fq3m9WujQOF3z3oORz9E/dW7Wrn50ZFhxqVM0pCZFr/xlfaGNvOgNXomkOZOGGNrv/dsuRHedP/xwqjp1UhJQVFIOwNiNrXo+U9hW+PyVek6eVSuxKqEZnXck2N+8+rznue3WizNx04A27tq81omvn0mDfSb0aYE1+/THwM2dNARbDhcBCHw1keLxs5p2Uz8cPV2CZ7/f6rpqCfHGQkpiPFY9OhJZT88N6XV6WjdIxb4TxZrPtW2UhtwnxqDHlNle24N9JttrtIhvGdjGcJkMAK7s3Rxfrjngte3Zqy/ApNGdonrfgi1ucvP9Ix/aubFfl75I9CJZmj0cqx4dibyp43GzgbJB4zopmDtpiPtSNCk+Ds11VvXxzJ/xccJQ979gN77U5z0nARNCeN1Q1fq59G/bwPDJo4WBVrgnve/q3uEd0LphKsZ0b6rEVfVcusZEX5798sPlexUXjrFBuv/5JkkJGfR78Hx6tDK6+H6Nq7dAstr4DxJLjI+LSO+tUDBxxyD1b7W6LYBmdVN0RwOqtUHP1lqo0lMSQ/4j79C4ttcN0J8nj6j28X0FOxep5QYjI0Xb+Azb9/0d6NWOgyWddgZrspf49BTx/NaSAtxAltI1iRnguu8QrZ5Nnl6/sQ8euqxz8B01VIZYcntoTBesfnSk4TEa6qjeQFcw0cTETW5jezRDh8a18fHtA/DAKO2WyCs39MF7v73QUJ020ow1CiPfX6tb8zrY8uRlGHeB/mIURhd3VmvTE3xWAwo2pe9nd7qWeR3U0ZWY/6xzYvVNK+qVwLu/zfLarvZq+eJP
F+P2QW2RkhiHpnVTMO2mvvj3b6pujP46q5V75CsAbH0q8OyS4SyWISDcJ1Ej0z/cPKCqV0yllAGT8F/HeJ8Q4uNESCUStaHieZK3ckZIJu4azvOD2CAtCXMnDQl4x71OSiKGGZyj5O1bsnBVnxYRu4S2ciGGYN3RVMFCVH/cPVrUDWlhi0a1k7H1qcvw3m8vdH99y8Dg3fnU49WtleiV1Kde09MdxyPju7lPPJf1aIZ6qUnuz0WL+rXw4R8uwsw7B2LKr7oFvdJa/HBVH/k/hjEwx/d3rfW790zUaUkJaFk/FT8+MFhzfqG7h3WIyCnd82f44nW9TF2cJBAm7hhm9ViBXq3q4V+/7h2xGdOeveoC9M9sgI5N9LtdGT3UkE5VXQq7RWDkY/N6KWhWNwV/V7oHhrpghZGTUkpivN9AJb/38fl68tiuaN0gFV2b1dHsRqlHLQmo75eV2QC3GRja7lnrDXXStB4tqn4Pobzy3hEd3fcyOjVJR0Kc989o+s39QoojoEjOkxwGJu4YpF6uBpqn24l6taqHz+4c6K7FaunZ0lhXyv/+rr+77OB7GV0dyQnxWDZ5BIZ3aYKnJvTAtz6LNqv0/u6DlUqMvpfvpv5tG2DRQ8OQmpSAG0Ps4gdon1Ca1DF2BRXq+bpNwzR3aSe8cQLeP4XRyo1a9WohoxpXgGpcF0dotGm4mLhjUGpSAmbcPgDv3Oo7v4g+q1vnkVKdBRsiPYfyzQPaoF2G9lWB3qHs9vMf3sXV62JwJ/9WuudJplFt/bpydU5Gnh4cHbjHR6iN385N0/HCtT3x0vW9A+530wD9E1x7nd9rtDFxx6iB7RsGXbgAqBpFaPY8D3ak1nUj0cXRKN1kpmwOpcRyY4jDvEPJo/3a1Efe1PGaVzBGE7IQwKd36K8I41mu+uB3/QHA6+bkPcM74mkDi4qE4rqsVpp/F4+O74oLlSXz6tVKwobHR+OSDlXTI9dJqf5C3WbgAJwarm2jNPznlixT5/A2y+z7B2N9/qmw3yeaVctgyTaUWLRmJYxGy907b+sfsVOTdFzUTv9zdeeQ9rh7WAfkHih0t+x9BzsF+nlE8mb1Hwa1w7nSCqzMcw1sqpOSiFqJrvR4YWZ9PGvCAhzhYOImjOpmbKkzuwm23mYw7r/7KGVurR4IDdKScOJsqX9M1RT4W4lMpmterxbyT54Lut/oIEvoxQlXDd5z5SPfi58rejbH12sPuCdDMyJ7bFc8NHM9LmhZL8zFe13B3D6onaGr12hi4qaYZKSrXFWrzh49BQDj/cFV/7v7Ekx4Y6lJ0WibdlM/9DWwKHOwXiVa32ul++ak67m6qYmY+aeLQ4rv2n4tw+pj7ft5CPV3Eg1M3BRzjPatVf8czS5xz7pvUMC5RVxBVO+9zZ6QTItn/+lAOS3YPV+t117VpwXeW7oH1/Rt4f+kydoqE6o1icL84uFi4qYaq+rmpLnH6dqsTtBVcty1XY3nQumHHO224W2XZOo+59tSbdcoDbuPna16XuM1mY3SsOFxY/O7RNrobk3x+o19MDrC88Gbgb1KqMay0wWwOplXikavErUfshHV6Q4ZjruGdtB9zjeUn3ymY7VbCSIpIQ6X92yOJGXiLZuMtdHExE01lpGeC9Hy/DU98fy1PTHj9gEY2K4hBrar3kCPQH3So50ntRKz57wn4cbzq17NAQBX9NKfQyYcdw5tj6SEONOm5w0HEzfVWNf2c0201M3CVdpv6O+KoWndFFyf1QodGtfGjDsGoFZSPGbdN8hrYQgj2jas3sou0fLhHy5yL1Ic7nmkfUZt5E0djw6NzRkhfGFmA2x/eqzf5FW/uag1/nCp+SvbB8IaN9VYl/VoatkkQaoHR3fG/SM7aa7ZaaQ27snq78Uo6dNzxGmesUGfbra4iSwkhAhroeWQjhWVowB9Wwfu6TL16p4Y2jkjIpN71VRscRPVEO0aRWeejQ9+fxEOF5boPt+teR28f1v/qMQSq9jiJqoh
ojX6r3ZyQlRXPK+J2OImqkEWPzQMZ0vLrQ6DwsTETVSDtKqBs0DGIpZKiIgchombyOHuHd4B3S3si07Rx1IJkcNNGt0Zk0aHv/waOQdb3ERkOt/RhxQetriJyFQLHhyKurXstRCB0zFxE5GpMhvZe/4UJ2KphIjIYQy1uIUQeQBOA6gAUC6lzDIzKCKyt1cm9kbDtGSrw6ixQimVDJNSHjMtEiJyjCt7R39pMarCUgkRkcMYTdwSwI9CiNVCiDvMDIiIiAIzWiq5VEp5QAjRGMAcIcRWKeUizx2UhH4HALRu3TrCYRIRkcpQi1tKeUD5/yiArwD4TaYrpZwupcySUmZlZGRENkoiInILmriFEGlCiHT1MYDRAHLNDoyIiLQZKZU0AfCVsj5cAoCPpZQ/mBoVERHpCpq4pZS7AfSKQixERGQAuwMSETmMkFJG/k2FKACwt5ovbwTAKQN9GKs5nBQr4Kx4Gas5IhFrGymloZ4dpiTucAghVjllSD1jNYeTYgWcFS9jNUe0Y2WphIjIYZi4iYgcxo6Je7rVAYSAsZrDSbECzoqXsZojqrHarsZNRESB2bHFTUREAdgmcQshLhNCbBNC7BRCZFsYx7tCiKNCiFyPbQ2EEHOEEDuU/+sr24UQ4lUl5g1CiL4er7lV2X+HEOJWE+JsJYSYL4TYLITYJIS4z66xKsdIEUKsEEKsV+J9QtneVgixXInrUyFEkrI9Wfl6p/J8psd7TVa2bxNCjDEjXuU48UKItUKI7+wcqxAiTwixUQixTgixStlm189BPSHETCHEViHEFiHEQBvH2ln5mar/ioQQ99siXiml5f8AxAPYBaAdgCQA6wF0syiWwQD6Asj12PY8gGzlcTaAfyiPxwGYBUAAGABgubK9AYDdyv/1lcf1IxxnMwB9lcfpALYD6GbHWJXjCAC1lceJAJYrcXwGYKKyfRqAPymP7wIwTXk8EcCnyuNuyucjGUBb5XMTb9JnYRKAjwF8p3xty1gB5AFo5LPNrp+D/wL4g/I4CUA9u8bqE3c8gMMA2tghXtO+0RB/KAMBzPb4ejKAyRbGkwnvxL0NQDPlcTMA25THbwG4wXc/ADcAeMtju9d+JsX8NYBRDok1FcAaABfBNWghwfdzAGA2gIHK4wRlP+H72fDcL8IxtgQwD8BwAN8px7ZrrHnwT9y2+xwAqAtgD5R7a3aOVSP20QCW2iVeu5RKWgDY7/F1vrLNLppIKQ8pjw/DNfEWoB93VL8f5dK8D1ytWNvGqpQe1gE4CmAOXC3QU1LKco1ju+NSni8E0DCK8b4M4CEAlcrXDW0cq9ZCJ3b8HLQFUADgPaUE9bZwzThqx1h9TQQwQ3lsebx2SdyOIV2nTNt0xRFC1AbwBYD7pZRFns/ZLVYpZYWUsjdcrdn+ALpYHJImIcTlAI5KKVdbHYtBl0op+wIYC+BuIcRgzydt9DlIgKsM+aaUsg+As3CVGtxsFKubci/jCgCf+z5nVbx2SdwHALTy+Lqlss0ujgghmgGA8v9RZbte3FH5foQQiXAl7Y+klF/aOVZPUspTAObDVW6oJ4RQZ6n0PLY7LuX5ugCORyneSwBcIYTIA/AJXOWSV2waK6T2Qid2/BzkA8iXUi5Xvp4JVyK3Y6yexgJYI6U8onxtebx2SdwrAXRU7tonwXVZ8o3FMXn6BoB6J/hWuOrJ6vZblLvJAwAUKpdQswGMFkLUV+44j1a2RYwQQgB4B8AWKeVLdo5ViTdDCFFPeVwLrnr8FrgS+LU68arfx7UAflJaN98AmKj05GgLoCOAFZGMVUo5WUrZUkqZCddn8Scp5W/sGKvQX+jEdp8DKeVhAPuFEJ2VTSMAbLZjrD5uQFWZRI3L2njNLOiHWPwfB1fPiF0AHrEwjhkADgEog6uF8Hu46pXzAOwAMBdAA2VfAeANJeaNALI83ud3AHYq/24zIc5L4bpE2wBgnfJvnB1jVY7RE8BaJd5cAH9XtreDK5nthOtS
NFnZnqJ8vVN5vp3Hez2ifB/bAIw1+fMwFFW9SmwXqxLTeuXfJvVvx8afg94AVimfg//B1cvClrEqx0mD6+qprsc2y+PlyEkiIoexS6mEiIgMYuImInIYJm4iIodh4iYichgmbiIih2HiJiJyGCZuIiKHYeImInKY/wdbXnV14eiOAAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(loss_change_self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_loss_change_self = [x for x in val_loss_change]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1255649b0>]"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHvhJREFUeJzt3XtwXOWZ5/Hv093qltS6y7IsX2V8gQDBYMQtWRJmCJPAsLAzQ3ahMpX7smyxk8vM1FbYVJLa1P6TmtTOkM0sLEV2kkwlTDYMyRAmCSSESZgkmIiLDcY2NvhuS5Yt62LJunU/+0cf2e22rnbL3X3696nqUp9zXnU/PmX9+u33vOccc3dERCRcIoUuQERE8k/hLiISQgp3EZEQUriLiISQwl1EJIQU7iIiIaRwFxEJIYW7iEgIKdxFREIoVqg3XrRokbe3txfq7UVEStJLL7101N1bZmtXsHBvb2+ns7OzUG8vIlKSzGzvXNppWEZEJIQU7iIiIaRwFxEJIYW7iEgIKdxFREJI4S4iEkIKdxGRECq5cN/RNchXn95B79BYoUsRESlacwp3M/u0mb1uZlvN7DNTbDcz+5qZ7TKzLWa2Mf+lZrzdc4KvP7eLrv6RhXoLEZGSN2u4m9nlwH8ErgU2ALeb2dqcZrcC64LHvcBDea7zlJrKzEm1Q2MTC/UWIiIlby4993cAm9x92N0ngF8Cf5zT5k7g257xAtBgZm15rhWAmkQm3E+MKNxFRKYzl3B/HbjRzJrNrBq4DViR02YZsD9r+UCw7gxmdq+ZdZpZZ09PzzkVPBnug6MKdxGR6cwa7u6+DfgK8AzwU+BVIHUub+buj7h7h7t3tLTMelGzKZ0allG4i4hMa04HVN39G+5+tbu/BzgOvJnT5CBn9uaXB+vyLqlhGRGRWc11tszi4OdKMuPt381p8iTw4WDWzPVAv7sfzmulgWQ8CHf13EVEpjXX67n/o5k1A+PA/e7eZ2b3Abj7w8CPyYzF7wKGgY8tRLEA0YiRjEcV7iIiM5hTuLv7jVOsezjruQP357GuGSUTMY25i4jMoOTOUIXMQVXNlhERmV5phnsipgOqIiIzKNlw17CMiMj0SjLck4mYDqiKiMygJMO9VuEuIjKjkgz3mkqFu4jITEoy3CenQmZmYIqISK6SDPeaRIzxlDM6kS50KSIiRalkwx10CQIRkemUdLhrOqSIyNRKM9yDy/4O6kQmEZEplWa4q+cuIjKjkg53jbmLiEytJMM9qXAXEZlRSYZ7baXCXURkJiUZ7jW61Z6IyIxKMtyr41HMdEBVRGQ6JRnuZkZNXDfsEBGZTkmGO+hWeyIiMynZcNeVIUVEpley4Z65YUeq0GWIiBSlkg332kSMEyPjhS5DRKQolWy41+huTCIi05pTuJvZZ81sq5m9bmaPmVllzvaPmlmPmb0aPD65MOWeljmgqmEZEZGpzBruZrYM+BTQ4e6XA1Hg7imafs/drwwej+a5zrPUVsYY1LCMiMiU5josEwOqzCwGVAOHFq6kuUkmogyNpXSrPRGRKcwa7u5+EPgqsA84DPS7+zNTNP0TM9tiZo+b2Yo813mWmkQFqbQzMq5b7YmI5JrLsEwjcCewGlgKJM3sT3Oa/Qhod/crgJ8B35rmte41s04z6+zp6Tmvwmt08TARkWnNZVjmfcBud+9x93HgCeBd2Q3c/Zi7jwaLjwJXT/VC7v6Iu3e4e0dLS8v51E1NIgoo3EVEpjKXcN8HXG9m1WZmwM3AtuwGZtaWtXhH7vaFUJOoAHRlSBGRqcRma+Dum8zsceBlYAJ4BXjEzL4MdLr7k8CnzOyOYHsv8NGFKzkjqZ67iMi0Zg13AHf/EvClnNVfzNr+APBAHuuaVe1kz13hLiJyltI9QzU4oKq57iIiZyvZcG+szvTcjw8r3EVEcpVsuNdVVhCNGMeHxgpdiohI0SnZcI9E
jMbqCnqHFe4iIrlKNtwBmpJxek8o3EVEcpV0uDdWx9VzFxGZQkmHe1MyTq/G3EVEzlLy4a4DqiIiZyv9cB8eI53WZX9FRLKVdLg3VsdJO/Sf1Fx3EZFsJR3uzTVxAB1UFRHJUdLh3lidCXeNu4uInKmkw70pmQn3Ywp3EZEzhCLc1XMXETlTSYf75LCMxtxFRM5U0uFeFY9SVRHVJQhERHKUdLhDcJaqeu4iImcIRbhrzF1E5EwlH+6Nur6MiMhZSj7cmzUsIyJylpIP98bqOMeHdPkBEZFsJR/uzTVxToxOMDqRKnQpIiJFo+TD/fQlCNR7FxGZNKdwN7PPmtlWM3vdzB4zs8qc7Qkz+56Z7TKzTWbWvhDFTqUpWQGgg6oiIllmDXczWwZ8Cuhw98uBKHB3TrNPAMfdfS3w18BX8l3odJqSCUDhLiKSba7DMjGgysxiQDVwKGf7ncC3guePAzebmeWnxJmd6rlrxoyIyCmzhru7HwS+CuwDDgP97v5MTrNlwP6g/QTQDzTnt9Sp6bK/IiJnm8uwTCOZnvlqYCmQNLM/PZc3M7N7zazTzDp7enrO5SXO0lAdx0yX/RURyTaXYZn3Abvdvcfdx4EngHfltDkIrAAIhm7qgWO5L+Tuj7h7h7t3tLS0nF/lgWjEaKqOc/TEaF5eT0QkDOYS7vuA682sOhhHvxnYltPmSeAjwfO7gF+4+wW7a3VLbYIjAwp3EZFJcxlz30TmIOnLwGvB7zxiZl82szuCZt8Ams1sF/DnwOcWqN4ptdZVcmRw5EK+pYhIUYvNpZG7fwn4Us7qL2ZtHwE+mMe65qW1LsH2roFCvb2ISNEp+TNUARbXVnL0xBip9AUbCRIRKWqhCPfWugSptHNsSOPuIiIQknBvqc1cDUEHVUVEMkIR7q11mUsQ6KCqiEhGKMJ9cV2m596tnruICBCScG+pCXruCncRESAk4R6PRWhOxunWsIyICBCScAedpSoiki004a6zVEVETgtNuC+uTdA9oHAXEYEQhXtrnc5SFRGZFKJw11mqIiKTQhPuOktVROS00IS7zlIVETktNOGus1RFRE4LTbjrLFURkdNCE+7xWIQmnaUqIgKEKNwhM9f9iOa6i4iEK9yX1FfSpXAXEQlXuLfVV3G4T+EuIhKqcF9aX8mxoTFGxlOFLkVEpKBCFe5L6jPTIbv61XsXkfIWqnBf2lAFwKH+kwWuRESksEIV7m1Bz13j7iJS7mYNdzO72MxezXoMmNlnctrcZGb9WW2+uHAlT6+tPtNz14wZESl3sdkauPsO4EoAM4sCB4EfTNH0eXe/Pb/lzU9VPEpjdQWH+jQsIyLlbb7DMjcDb7n73oUoJh/a6qs4rAOqIlLm5hvudwOPTbPtBjPbbGY/MbPLpmpgZveaWaeZdfb09MzzredmaUOleu4iUvbmHO5mFgfuAL4/xeaXgVXuvgH4X8APp3oNd3/E3TvcvaOlpeVc6p2Veu4iIvPrud8KvOzu3bkb3H3A3U8Ez38MVJjZojzVOC9L6ivpPznO8NhEId5eRKQozCfc72GaIRkzW2JmFjy/NnjdY+df3vwtbchMhzyk6ZAiUsbmFO5mlgRuAZ7IWnefmd0XLN4FvG5mm4GvAXe7e0HuVD05HfKwTmQSkTI261RIAHcfAppz1j2c9fzrwNfzW9q5WXoq3NVzF5HyFaozVAFa6zN3ZNJZqiJSzkIX7olYlEU1CQ3LiEhZC124QzDXXcMyIlLGQhnubfWVHNaJTCJSxkIa7lUc6jtJgSbsiIgUXCjDfWVTNUNjKY4NjRW6FBGRgghluK9qrgZg77HhAlciIlIYoQ73/b0KdxEpT6EM9+WN1Zip5y4i5SuU4V5ZEWVJXSV7e4cKXYqISEGEMtwhc1B1n3ruIlKmQh3uezXmLiJlKrThvqq5mp7BUV3XXUTKUmjDfWVzEoD9vTpTVUTKT2jDfVXT5Fx3HVQVkfIT3nAP5rrv07i7iJSh0IZ7
fVUFtZUxzXUXkbIU2nA3M1Y1a8aMiJSn0IY7wKqmpC5BICJlKdThvrK5mgPHh0mldelfESkvoQ73VU3VjKecg8c1HVJEykuow33N4hoA3uo5UeBKREQurFCH+9qWTLjvOqJwF5HyMmu4m9nFZvZq1mPAzD6T08bM7GtmtsvMtpjZxoUree4ak3EW1cTZeWSw0KWIiFxQsdkauPsO4EoAM4sCB4Ef5DS7FVgXPK4DHgp+Ftyalhr13EWk7Mx3WOZm4C1335uz/k7g257xAtBgZm15qfA8rWvNhLtuli0i5WS+4X438NgU65cB+7OWDwTrCm5tSw0DIxP0DI4WuhQRkQtmzuFuZnHgDuD75/pmZnavmXWaWWdPT8+5vsy8rF1cC+igqoiUl/n03G8FXnb37im2HQRWZC0vD9adwd0fcfcOd+9oaWmZX6XnaF1rMGNG0yFFpIzMJ9zvYeohGYAngQ8Hs2auB/rd/fB5V5cHi2sT1CZi7OxWuItI+Zh1tgyAmSWBW4D/lLXuPgB3fxj4MXAbsAsYBj6W90rPkZmxZrFmzIhIeZlTuLv7ENCcs+7hrOcO3J/f0vJn3eIa/uXNCzPGLyJSDEJ9huqktYtr6BkcpX94vNCliIhcEGUT7oDOVBWRslEW4X7p0joAXjvYX+BKREQujLII97b6KpbWV9K593ihSxERuSDKItwBrm5v4qU9x3UZAhEpC2UT7h2rGukaGOFQ/0ihSxERWXBlE+5Xr2oEoHNPb4ErERFZeGUT7pcsqaU6HuUljbuLSBkom3CPRSNctbKBzj0KdxEJv7IJd4CrVzWxvWuAE6MThS5FRGRBlVW4d6xqJO3wyj713kUk3Moq3K9a2UA0Yvx617FClyIisqDKKtxrKyt47/oWfvDKAVJpzXcXkfAqq3AH+Pcdy+keGOVXO3WVSBEJr7IL99+/pJWmZJzHOw8UuhQRkQVTduEej0X4d1cu45k3uugdGit0OSIiC6Lswh3ggx3LGU85P3zlrNu8ioiEQlmG+zva6rhqZQNff24Xh/tPFrocEZG8K8twB/jqBzcwOp7i/u+8zNhEutDliIjkVdmG+5qWGr5y1xW8vK+Pv/z+Zt44NKDLAYtIaMzpBtlhdfsVS9l2eICH/uUtntx8iItaktz3njX80cZlVETL9nNPRELACtVb7ejo8M7OzoK8d66jJ0b52RvdfHfTPl472M/yxiru/721/MnG5cRjCnkRKR5m9pK7d8zaTuF+mrvz3I4jPPjznWw+0M+yhiruuno5H7h8CZcsqcXMCl2iiJS5vIa7mTUAjwKXAw583N1/m7X9JuCfgN3Bqifc/cszvWYxhvskd+eXb/bw8C/fYtPuXtwz14P/5I0XcceGperNi0jB5DvcvwU87+6PmlkcqHb3vqztNwF/6e63z7XAYg73bD2DozzzRhff/s1ednQPsrg2wUfe1c6HrltJQ3W80OWJSJnJW7ibWT3wKnCRT9M4zOE+yd351c6jPPr82zy/8yg1iRifed86PvKudh18FZELZq7hPpdUWg30AH9nZq+Y2aNmlpyi3Q1mttnMfmJml8234GJnZrx3fQt//4nr+Mmnb+Sa9kb+xz9v49YHn+e7m/bpBiAiUlTm0nPvAF4A3u3um8zsQWDA3b+Q1aYOSLv7CTO7DXjQ3ddN8Vr3AvcCrFy58uq9e/fm8Z9yYbk7z247wl89vYMd3YNUx6PcsWEp91y7kiuW1+vgq4gsiHwOyywBXnD39mD5RuBz7v6HM/zOHqDD3Y9O16bUhmWm4+68sr+Pxzbt46kthzk5nmLt4hr+8J1t3HRxCxcvqaU6XtanE4hIHs013GdNHXfvMrP9Znaxu+8AbgbeyHmzJUC3u7uZXUtmuKcsbndkZmxc2cjGlY184d9eyo82H+JHmw/xtV/s5MFnd2IGF7fWcs+1K/njjcuorawodMkiUgbmOlvmSjJTIePA28DHgP8A4O4Pm9l/Af4zMAGcBP7c
3X8z02uGpec+nSODI7yyr4/thwf5xfZuNh/oJxGLcOnSOq5YVs971rfw7rWLqKyIFrpUESkhOompyGze38ePNh9iy8F+Xj/Yz/BYiqqKKO9d38Itl7byrrXNLKmr1Fi9iMwob8Mykh8bVjSwYUUDAKMTKV54u5efvdHFz984wk+3dgFQVxljw4oG/uDSVm65dAlL6isLWbKIlDD13AvM3XntYD+b9/exvWuQF94+xls9QwCsXpTkmvZGrl3dzLXtTaxoqlLPXqTMqedeIsyMK5Y3cMXyhlPrdh0Z5NltR/jdnl6e3trN/wvu99pal+Ca9ibes76F91+2hPoqHZwVkamp517k0mln55ETvLj7GC/uOc6Lu4/RPTBKPBrh+jXNXLe6iY5VjWxY0aCDsyJlQAdUQ8rd2Xygn6c2H+JXO3t4s/sEABVRY8PyBu66ejl3XrmMqriCXiSMFO5lom94jJf2Hud3e47z3PYj7OgepL6qgmtXN/GOtjo2rmzg2tVNOpFKJCQU7mXI3Xlxdy//8Lv9bD7Qx56jQ6Q906u/bnUzd1y5VGP1IiVO4S4Mj03Quec4v951lJ9u7WLvsWEAVjRVccmSOn7v4sXccmkrLbWJAlcqInOlcJczuDuv7u/jX3ce5c0jJ3h1/3H2957EDDaubOT9l2Xm1q9eNNUFP0WkWCjcZUbuzo7uQZ7Z2s3TW7vYemgAgDUtSW66eDE3XNTMNaubNIQjUmQU7jIvB44P8+y2I/x8WzebdvcyNpEmYnDZ0npuWNPMu9cu4tr2Js3CESkwhbucs5HxFK/u7+O3bx3jhbeP8cq+PsZSaeKxCB2rGrlxXQs3rlvEpW11RCI6Y1bkQlK4S96cHEvx4p5e/nVnD8/vPMr2rkEAmpJxbljTzLvWNNOxqonVi5K6ebjIAtPlByRvquKZq1e+d30LkLmc8a93HeX5nUf5za5j/POWwwBEI8ZFi5KnhnGuWtnA4lpd/EykENRzl/Pi7uw+OsSWA/281XOC1w72s+ntXk6Op4DM9XDeuayBdy6rZ31rDcsbq1ndkqQmoX6FyLlQz10uCDPjopYaLmqpObVubCLN5gN9bDmQuXb9lgN9PLu9m8l+RMTgHW11XNPeREd7Ix2rmmitS+iKlyJ5pJ67XBAnRifYe2yI/b0n2XZ4gN/t6eWVfX2nevjNyTgXL6llUU2CmsoYFy1KcvWqRi5bWq9xfJEs6rlLUalJxLhsaT2XLa3nA5cvAWA8leaNQwO8tPc427sGeLP7BFsO9DEwMkHv0BgAsYhxUUuS9a21XLKkllXNScYm0oxOpLlieb1m7IhMQ+EuBVMRjZxxh6ps3QMjvLT3OK8f7GdH1yCv7u/jqeDAbbaW2gRXrWhgXWsN61trWbu4hjUtNbr8sZQ9hbsUpda6Sm57Zxu3vbPt1LrBkXEO9p2kMhYlGjFe3N3LL9/sYeuhfp7dfoRUOjPEGI0Y6xbXcNnSei5fVsdlS+tZ0VRFS02CWFRDPFIeNOYuoTA2kWbPsSHe7B5k++FBth7q57WDAxw9MXqqTTRiLK5N0FZfSVt9FW31lSxtqGJVczWrmqtZ3litHr8UPY25S1mJxyKsb61lfWstt19xev2RgRG2Hh7gUN9JuvpHONQ3QtdA5qDus9u7GRlPn2prBi01CRIVEaJmrGiq5h1tdaxoqqapOk5TMvNoTFbQWB2nQt8CpIgp3CXUFtdVsrhu6hOp3J3eoTH29g6z79gwe48Nc7BvmImUM5ZKs/voEN/89R7GUukpfz8ZjzKRdibSTmUsQk1ljLWLa7i2vZn1rTUkEzHqqipoq69kUU2CVNpJu+vbgVwQCncpW2ZGc02C5poEG1c2TtlmIpWmd3iM3qHTj+NDY/QOjdN/cpyKqBGLGiPjaQZOjvP6oQH+5tk3mWm0c1lDFVcsr2dlUzWNyfipbwWNydPfDuoqY5r3L+dlTuFuZg3Ao8DlgAMfd/ffZm034EHgNmAY+Ki7
v5z/ckUurFg0wuLaynldRqF/OHPgd3hsguPD43QNjHDsxCixYMrm9q5BXjuYOQg8NjH1t4JYxGiojtMcDANNhr5hDI1OEI1Y5thBw+ljB0vqK6mr1CWaJWOuPfcHgZ+6+11mFgeqc7bfCqwLHtcBDwU/RcpOfXUF9dWzh6y7MzyWynwbyP12MJz5dtA7NMrxoXHe7D5B79AY7k4yEWMi5RwZHCGd8w2hJhGjtjJGVTxKdTxKdTxGdTxKMvhZHY9SnYiRzNo2uVwVtEsmsrbFYzqJrETNGu5mVg+8B/gogLuPAWM5ze4Evu2ZqTcvmFmDmbW5+9kTk0UEyAwLJRMxkokYK5py+0uzm0il6R4c5XDfSQ71j3C47ySH+0c4MTrBybEUw2MTDI+lOHZijP1jwwyPpRgeSzE0OsFE7qfCDCqilhX2USqiEcyMaAQiZpgZEYOKSOa4Q+2pRwWJWISKaCQzfBUJfkYjRCNGRdSIRiLEIpZ5BG0yz7PbGBWTy5EI0ahREcmsj0UjZ/xuVCe0nTKXnvtqoAf4OzPbALwEfNrdh7LaLAP2Zy0fCNadEe5mdi9wL8DKlSvPo2wRiUUjLGuoYllD1bx/d2wifSr8h8cmGBpNnX4+lmJ4dOKM5ZPBh8LwWIqJdJpUOvPNI+1OyjPPxybSdA+M8FbPBIMjEwyOjDOeurBTrc0yQ1pThbyR+RCKmBGJnH4++eGUeT7ZdvL1znydU9vt9GtOLuf+zqnfnOI1775mBZ+88aLz+rfOZi7hHgM2An/m7pvM7EHgc8AX5vtm7v4I8Ahk5rnP9/dFJD/isQjxWJyG+X9hmJd02hlPp5lIOeOpNOMpJ5XOPE+lnYl0OjPjKOXBz+zl9Kn1p9qmZv6dVDrNeNA+N949qCfljjukgw+ndPDhNHkS3OTB8MmAOr185obT232Ktqe3ZS9PPllUs/A3pZ9LuB8ADrj7pmD5cTLhnu0gsCJreXmwTkTKWCRiJCJRdIXnC2/WIyXu3gXsN7OLg1U3A2/kNHsS+LBlXA/0a7xdRKRw5vp5+mfAd4KZMm8DHzOz+wDc/WHgx2SmQe4iMxXyYwtQq4iIzNGcwt3dXwVyr2XwcNZ2B+7PY10iInIeNIFVRCSEFO4iIiGkcBcRCSGFu4hICCncRURCqGB3YjKzHmDvOf76IuBoHstZaKp3YanehVNKtUJ51LvK3Vtma1SwcD8fZtY5l9tMFQvVu7BU78IppVpB9WbTsIyISAgp3EVEQqhUw/2RQhcwT6p3YanehVNKtYLqPaUkx9xFRGRmpdpzFxGRGZRcuJvZB8xsh5ntMrPc68oXnJmtMLPnzOwNM9tqZp8O1jeZ2c/MbGfws7HQtU4ys6iZvWJmTwXLq81sU7CPvxdcDbQoBLdwfNzMtpvZNjO7ocj37WeD/wevm9ljZlZZTPvXzP6vmR0xs9ez1k25P4NLen8tqHuLmW0sknr/Kvj/sMXMfmBmDVnbHgjq3WFm7y+GerO2/YWZuZktCpbzun9LKtzNLAr8LZkbcl8K3GNmlxa2qrNMAH/h7pcC1wP3BzV+DnjW3dcBz3L2DU8K6dPAtqzlrwB/7e5rgePAJwpS1dQmb9Z+CbCBTN1FuW/NbBnwKaDD3S8HosDdFNf+/SbwgZx10+3PW4F1weNe4KELVGO2b3J2vT8DLnf3K4A3gQcAgr+7u4HLgt/530GGXEjf5Ox6MbMVwB8A+7JW53f/unvJPIAbgKezlh8AHih0XbPU/E/ALcAOoC1Y1wbsKHRtQS3LyfwB/z7wFJlbPR4FYlPt8wLXWg/sJjhWlLW+WPft5L2Fm8hcXvsp4P3Ftn+BduD12fYn8H+Ae6ZqV8h6c7b9EfCd4PkZ+QA8DdxQDPWSuaPdBmAPsGgh9m9J9dyZ/kbcRcnM2oGrgE1Aq5++O1UX0FqgsnL9DfBfgXSw3Az0uftEsFxM+zj7Zu2vmNmjZpakSPetux8Evkqmd3YY6Cdzg/li3b+T
ptufpfD393HgJ8HzoqzXzO4EDrr75pxNea231MK9ZJhZDfCPwGfcfSB7m2c+lgs+TcnMbgeOuPtLha5ljiZv1v6Qu18FDJEzBFMs+xYgGKu+k8yH0lIgyRRf0YtZMe3P2ZjZ58kMi36n0LVMx8yqgf8GfHGh36vUwr0kbsRtZhVkgv077v5EsLrbzNqC7W3AkULVl+XdwB1mtgf4BzJDMw8CDWY2eZeuYtrHU92sfSPFuW8B3gfsdvcedx8HniCzz4t1/06abn8W7d+fmX0UuB34UPCBBMVZ7xoyH/abg7+75cDLZraEPNdbauH+O2BdMNsgTuZgyZMFrukMZmbAN4Bt7v4/szY9CXwkeP4RMmPxBeXuD7j7cndvJ7Mvf+HuHwKeA+4KmhVFrTDjzdqLbt8G9gHXm1l18P9ist6i3L9ZptufTwIfDmZ1XA/0Zw3fFIyZfYDM0OId7j6ctelJ4G4zS5jZajIHKl8sRI2T3P01d1/s7u3B390BYGPwfzu/+/dCH1zIw8GJ28gcEX8L+Hyh65mivn9D5mvsFuDV4HEbmbHsZ4GdwM+BpkLXmlP3TcBTwfOLyPwR7AK+DyQKXV9WnVcCncH+/SHQWMz7FvjvwHbgdeDvgUQx7V/gMTLHA8aDoPnEdPuTzMH2vw3+9l4jMwuoGOrdRWasevLv7eGs9p8P6t0B3FoM9eZs38PpA6p53b86Q1VEJIRKbVhGRETmQOEuIhJCCncRkRBSuIuIhJDCXUQkhBTuIiIhpHAXEQkhhbuISAj9f7FDnMAWWpRbAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(val_loss_change_self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare with reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "loss_change = []\n",
    "val_loss_change = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "config.set(\"is_ref\", True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "transformer_xl_ref = MemTransformerLM(\n",
    "    len(vocab), n_layer=config.n_layers,\n",
    "    n_head=config.n_heads, d_model=config.d_model,\n",
    "    d_head=config.d_head_inner, d_inner=config.d_ff_inner, \n",
    "    dropout=config.dropout, dropatt=config.dropouta,\n",
    "    ext_len=0, tgt_len=config.train_bptt, mem_len=config.mem_len,\n",
    "    pre_lnorm=False,\n",
    ")\n",
    "if torch.cuda.is_available(): transformer_xl_ref.cuda()\n",
    "transformer_xl_ref.apply(ref_weights_init)\n",
    "optimizer = optim.Adam(transformer_xl_ref.parameters(), lr=config.lr)\n",
    "scheduler = scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,\n",
    "                        config.max_step, eta_min=config.min_lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "381842"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "num_params(transformer_xl_ref)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| epoch   0 step     3400 | lr 0.000132 | loss  6.18 | ppl   482.748: 100%|██████████| 3522/3522 [21:27<00:00,  2.74it/s]  \n",
      "| epoch   1 step     7000 | lr 2.41e-08 | loss  6.08 | ppl   438.641: 100%|██████████| 3522/3522 [20:35<00:00,  2.85it/s]  \n"
     ]
    }
   ],
   "source": [
    "#scrap\n",
    "train(\n",
    "    transformer_xl_ref,\n",
    "    train_iter,\n",
    "    valid_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 6.006121990051998, 'ppl': 405.9061560101621}"
      ]
     },
     "execution_count": 108,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "evaluate_final(transformer_xl_ref, valid_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "loss_change_ref = [x for x in loss_change]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1255cad68>]"
      ]
     },
     "execution_count": 118,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd8FHX6wPHPd3dTSICEEnoJTZogvSlFxEKxN7D3cnq2U3/YTvQ8ezlsp+jZRT3LKbazdzkQBAWld5ASek2y5fv7YyZbsrMt2c3OJs/79VJmZ2Znnk02z8x8q9JaI4QQInM40h2AEEKIxEjiFkKIDCOJWwghMowkbiGEyDCSuIUQIsNI4hZCiAwjiVsIITKMJG4hhMgwkriFECLDuFJx0KZNm+ri4uJUHFoIIWqluXPnbtVaF8Wzb0oSd3FxMXPmzEnFoYUQolZSSq2Jd18pKhFCiAwjiVsIITKMJG4hhMgwkriFECLDSOIWQogMI4lbCCEyjCRuIYTIMLZK3N8uLWHNtn3pDkMIIWwtJR1wqqR0NyOmd8ajHXDHjnRHI4QQtmWfO+6cBgC4lI+f3nwgzcEIIYR92SdxK8U8X2cABv52V5qDEUII+7JP4ga6XfKCf1lrnb5AhBDCxmyVuOu16eVfnrliWxojEUII+7JV4g724vdL0x2CEELYku0S9772owHIXvphmiMRQgh7sl3idnQYCcCJzu/THIkQQthTXIlbKXW1UmqhUuo3pdQ1KQ2o90kAjHbOT+VphBAiY8VM3Eqpg4GLgUHAIcAEpVTnVAWUk1eQqkMLIUStEM8dd3dgltZ6v9baA3wDnJSyiLLr+xfdXl/KTiOEEJkqnsS9EBiulGqilMoDxgFtK++klLpEKTVHKTWnpKSkGhE5Wa+bMsvXjR37yqt+HCGEqKViJm6t9SLgPuBT4L/AfMBrsd80rfUArfWAoqK4JiqOaI2vOQ7kblsIIazEVTmptf6X1rq/1noEsANIaSNrp9IMdCzFJ50nhRAiTLytSpqZ/7bDKN+ensqghjh+B8C9d2sqTyOEEBkp3mFd31ZKNQHcwBVa650pjMmvtPRATZxGCCEySlyJW2s9PNWBWDlQ5k7HaYUQwtZs13MSYGH7cwB49hsZr0QIISqzZeIuqd8NgE3bd6c5EiGEsB9bJu5+HZoBMKl/izRHIoQQ9mPLxO3KyQXA4ZHKSSGEqMyWiTtnwywAms66L82RCCGE/dgycTu10ZqkndqS5kiEEMJ+bJm41dArAFioi9MbiBBC2JAtEzc5DQAY6/wpzYEIIYT92DNxu3LTHYEQQtiWPRO3MyfdEQghhG3ZM3E77BmWEELYge0z5PIte9MdghBC2IrtE/dHCzamOwQhhLAV2ybu/doo5962tyzNkQghhL3YNnGXkgVAR/eyNEcihBD2YtvE3VgZZdtNFr2U5kiEEMJebJu4K8wsLU53CEIIYSu2Tdye5r0BONY5M82RCCGEvdg2cbu6jQNgiGNRmiMRQgh7sW3ipuvYdEcghBC2ZN/E3aoPADPzRqU3DiGEsBn7Jm5gi6MZbh3XRPRCCFFn2DpxexzZOH3l6Q5DCCFsxda3s9vKXZTrXekOQwghbMXWd9zbdUMaqb1s2V2a7lCEEMI2bJ24y8kiGw8XvTQn3aEIIYRt2LqopKdjFa3Udhas35HuUIQQwjZsfcfdSm03/mVbmiMRQgj7sHXirvBD7tXpDkEIIWwjIxK3EEKIAEncQgiRYTImcXt9Ot0hCCGELdg7cbce4F+cv05algghBNg9cY/8P/+iUiqNgQghhH3YO3EfdJR/8f3569MYiBBC2Ie9E3eQWTO/TXcIQghhCxmTuC9xfZDuEIQQwhZsn7jLshsDcILzxzRHIoQQ9hBX4lZKXauU+k0ptVAp9ZpSKjfVgVXwOGvsVEIIkRFiJm6lVGvgKmCA1vpgwAlMTHVg
FX7rfKl/edHG3TV1WiGEsK14i0pcQD2llAvIA/5IXUihOnfq7F927lhVU6cVQgjbipm4tdYbgAeBtcBGYJfW+tNUB1ahcV6Wf3ndxhq7XgghhG3FU1TSCDge6AC0AvKVUmdZ7HeJUmqOUmpOSUlJ8iLcuda/OH+NDO8qhBDxFJWMAVZprUu01m7gHWBY5Z201tO01gO01gOKioqSF+Ehk/yLQ3dKk0AhhIgnca8Fhiil8pTR7/wIYFFqwwqSU9+/uHHHHpZv2VNjpxZCCDuKp4x7FvAW8DOwwHzPtBTHZcmJjzEPSw9KIUTdFteck1rr24HbUxxLTCc4f+Qz7wBgfLpDEUKItLF9z0mA3WMe9C8/kf1oGiMRQoj0y4jE7es0Jt0hCCGEbWRE4s7KyUl3CEIIYRsZkbhd2TJeiRBCVMiIxJ2VJYlbCCEqZETidlRK3Gu37U9TJEIIkX4ZkbhxhIb54CeL0xSIEEKkX2Yk7kq+WLwp3SEIIUTaZGTibuGWiYOFEHVXxiRurQKhXud6M42RCCFEemVM4lZZef7lQxwr0xiJEEKkV8Ykbk59wb/YRm1NXxxCCJFmmZO4uxzJnMNf9b/cU+pOYzBCCJE+mZO4Ae3M9i/fPuO3NEYihBDpk1GJO7t8h39513654xZC1E0ZlbiV1v7lw/Z8nMZIhBAifTIqcTdtEBglsN2WL9m0qzSN0QghRHpkVOJuVRhoElig9jHkni/SGI0QQqRHRiVunFn+xQGOpWkMRAgh0iezEneHUZVWaKu9hBCiVsusxF1plMAHXE+nKRAhhEifzErclZzq+jbdIQghRI3LuMStg8YsEUKIuijjErc6498hr69/85c0RSKEEOmRcYmb4sNCXr41V8bmFkLULZmXuJUKedlFSeIWQtQtmZe4K/ks58Z0hyCEEDUq4xO3EELUNZK4hRAiw9SKxH3KP39MdwhCCFFjakXi3rNWmgQKIeqOjEzc3glTQ15/kjNZJlYQQtQZGZm4nR1HhK075M5P0xCJEELUvIxM3FY+zp6M1jJaoBCi9svMxF3YPmxVd8da7p4xPw3BCCFEzcrMxO1wWq4ePPcvNRyIEELUvMxM3BGMVPMpdXvTHYYQQqRUrUrcWcrLLa//kO4whBAipWpV4gY4e2UGjl2ydRnsksGyhBDxiZm4lVJdlVLzg/7brZS6piaCi+qiLy1X99DLqnXYlVPHsemb56p1jIQ9PgAe6Vmz5xRCZKyYiVtrvURr3Udr3QfoD+wH/pPyyGJp0CLphyx1e+m44wdafHVtYm9c9S3s3ZL0eIQQwkqiRSVHACu01mtSEUxCClpbr09SU+5lX78W/84vHot+dkxyTpyIvVvgoxvAK71GhahLEk3cEwHLjKaUukQpNUcpNaekpKT6kVVRtvLg80XP3r4Du/GV7om6T5evL0vovGrnGti9EUqWBFZOKaD8v7cmdJyEfHQDzJ4GSz5O3TmEELYTd+JWSmUDxwFvWm3XWk/TWg/QWg8oKipKVnxR7btsjuX612avZsueUuPF6h9g6/KQ7Y772uK7t13S4/E90hOeGBSyLvt/jyX9PH7abPqofak7hxDCdhK54x4L/Ky13pyqYBLlqNfYcv3t7/7KhEe/Z9nmPfDCOHi8P0s3h95huzCS3euffseKP4zy6er2mHfoBNuQl+4CT3mVz7dxVxkAJXvLIu/k88If86p8DiGE/SSSuCcRoZgkXZTTOvzFOedx+YFp3PyPaf51T0+9K2w/r9fLxB8nsO7p061P8Mkt1c/m0dzbDl49pcpv37jLeKpYv31/xH18X98L00ZJ8haiFokrcSul8oEjgXdSG05inE7rru8u5eN81ye8mXOnf91D2U+F72gm5eEYSU1Xrtmc+ThsW5GcYCNZ9U2V3xoWr4XVC2cCsGb18hh72oTXA/+5PLSuQAgRIq7ErbXep7VuorXeleqAEuFwuBLaf/cDfSzXK8u1hr0lq41KxyAvP/k3Pn/9
UdiyOPIbP7oh8rYDO+G1M6KcNYryffCrZTWDpV0HPADs2BelOMVOtvwGv0yHty9MdyRC2FZG95x0OBILv+G+VSGv35+7MuS1ValI/TdOhoe7hQwZe/aWBxmz+DZ4cjCU7bU+2exp1usB39wXYcmH8QXtKYefXwKfWQH54fXwzkWwbnbMt+5b+jV9Dxh33Ckt8kkmM84YDYOEqNMyOnGrBO+4AU57aqZ/edxHQwBwKI3Xp6Mmt/d+NpN+5Uf4eyK0J6/EvXUVvimN2LB0HgvW74w/4O8eghl/hoVvGa9/mQ5A6f7YDz/504/3L6d1rPKV38CGn+PadeOuAwBs2Hkg+XHs3x64AAqRwTI6caMSD/+2PwLts7NVoBXItuWzo5YZF8w1ysiX/+/92CeZcVXYqqzH++DAx663/ozeF2c79+kT4Zt7jeUDO0I2Ldu8F6fZiqXevg0A7Fk1h9I7W3Jgk0X58IGdsHsj+qUTmLNoZfj26prxZ+NpwMpLx8Ezh8d1mH3lxmcq95gJtnQ3TCmAOc9b7q+15h+fL2VrtJY1AHs2wf0d0N/eH1ccQthZZifuBItKAHo5Vluubzb9KBZ8+27E9zXcayS7zZvjaA3584uhr7//h39RA33WvRz7GABLgzrWlO+D30JHGuhU+hsA3RY+AECDF48g17efhZ++FHao8r3b8Xz3MGrlV3z4ysN89nuSW3X+/BL89EzSDqfNmgevWb9Q9p11e/ifl67h4u8O47mXXoh6vB2b1wKwdU76R2sQoroyO3En2dAfL464Ldt3AF4/kwZlGyPuE9Hnt/sXHVXtLPPFHfDmeUErtD+5hYlwjl/WGcUrt2e9zJ71v1ctjhhuvCXCXXcVrd9pNHncstu66CRn6wLyVRnjd74S9TgVd+S7DtTB4QH2b4ftKXjKEmmT+Yn7zLdq5DS99n4Piz+gd0kcRSVRdHf/FnnjbPOOddPCalUmamWd0N3eQEJvtW2m5T5x8Xrg+0fAHZ5M789Kzl13U99WAJT5WRxYX4yU/+IV9PPylBlFI8Lw+EB4tG+6oxBJlPmJu8uR/O7qke4okuOj69Gf3wFPHQr/+2eMnRXlQYmYpZ/4Fxsc2GD5jg575ga90kaLld1/wKIPEotz/ivw+RT49sGou/ke7Ip7+pn+1ytL9vorHyMyL1iF/panFok5hArZCzCaEj7UNfp5rGxaCM+NtbwgZbT9W9Mdgb1s/j3jR/PM/MQNtCrISXcISaO+f9hY+OSmqPttWPgDTdXuwIrpp/kXu2+0LqtvXhr0uKyBu4rg4e7wxpn4yuNIVqu/h/nTodzoqanLrZtCbv3iUUpevwLH3k1kLQ1cFGZOPYeX77866ikarPs65LXDnF/Un5hX/wDLPo8e5yLzqcjrhh2rzZXRWuub/jsZ1v4I63+KvW9VvXIyfJSBk33UJv8cClOt+3RUx6bHjmL9B/ck/bhWakXiLsy17kFZmx2zJbEiicrFJ5XLx8vLS5n95EWs+VuvyAd5YTy8ezmLNxkXDDXrKXj9zLDdmn53G0WLw8ucz3R9wY1Zb8CaH42WIjvXhu3TfG7oXbwyWw6piqKSF8bBqyeHva9H+QLY+EvIusXP/wmmHsLeHZtQle/YK1qrzA1UJPtqoqng8s9h9tOJvWf3H5ZFP+Uz/kL5R9Ev8NWyb5vxM4p1ocxE7n2B5SkFaIvvcaJabJtFmzn3Vvs48agViRtnVrojsB1dFjqo1pBlD1feI+SVx+Nh0JY3ae+tlEz3bIIda+CDwOQSu/YHDYy1+IOEy+P13BeMhdWx5wdVDuMCU6D3WBdhBF+Qfng0ZFP9Dd8BsGNbCQWrKw19WzFV3P+e9K9atXWf+a859sveElj6acwYU+7h7kbRj9Yhg5Jl//ws2bOfjPLGatpkXghnpnCEy3i5D6S0eEMtjlxcuP/Lhyhb8V3Kzl0VtSNxn/QMpdmN0h2Frax75IiE9vcFlZe/OWddYMNDXWFqb5gTmM7N
W7lb4x2FCZ1rzVajiGXNdiMRz1u7A++OdbA5vOJWmU0+8yhlyb2HRT+w9jFr5bbgFUa8e0oomv8EAJ29Kyp1wgkk/jKP0YZ8X7mHTz+ZAQ92humn4o2nGKkmzHrKKN6qofLZvaVGC5wd+23QEuelE+DBLmk5dd63d5Lz8oS0nDuS2pG4C9tSOuqv6Y7CVtqVxhqkKbSoxLnuR//yDW/9yoL1kXtmfrm4em3At5lN88rX/cz8dTs58ckfcU49GP45LHRHnzekk1VXb/SBsrbt3MH658/xv64oHtl3oDR0xzXfU5HUPUEXIX9xilIcNfNs//pytyeuz5Vqvl/eMBZ2rou+Y5Ks2mY8gazfEXn0yRqz7n/pjsBWakfiBgqLkj8xQu0Wetec/+55/uXvsq/Gs7ckYvfw/o6l1Tpz/91GmWmXVS/D7+/xSNYTlvtt+/JRHI7olYpN1wTGfGmy4StOdn7vf+2NVGS9bbn/IvHHrkBS9yduXak+oKJd/P7tlsU1Wmu27C4NW59s63dWDAeQukTq82kWbTQrveMoAdNac8dbs0KLz2rYndM/5/hbHk/b+QG276vZz19rEjddxlDa8ah0R1ErtHWU0Pf1/nCndfHTWGfyWl00/mUaJzqty7qXrliBQ0WveG6+JHLHm3YqwpNBUHm9J47RrPxF+Pd3AIu5RV/6/CfeuP9SVpdEnw6vukrdxgVkX1kKngBKlsL813j3/f/Q/em2LFy6LGhjhIunz8u2+w7h9oVHcd/fJxvrtIaZTxgXuVRY9jkvzVxN8eQPKXUbRVuTl5zGe1m3BPbR2ijaS1UMFvr97bMaOxfUpsQN5I5McHb2OmzoUnuM2VEWtRhCxdOIr1rydOAOusC7w3KfkDFsNi8M295rzi382fUue5d8Hf7mfVvDxpkB4L0rEg01EE+cvW/XPjSKpdPOib0jwJND4N3LaL/UGBOmbPn3Md4AevcfNC015g0/0mFMI1i6bj58cjO7Xj0/vvMm6tWTyf1sMp9k38iuvcaTh3/MIZ/PKF7b9KtxcX73T4kff9uKkJZGIdaHTpW4Y9sWfr57NJs21Pzc6bUqcdNuaLojEAnKdUceKdGh3ajSKHdNcbZmydln3SGpspbePyzXx2ommKPNAa6sEuoDneC+4vD186J30beyrzx6HHpKIaXvXuN/3W7PPA764734Dm4OWOYwL1JZ3n3h+5TtgSkFuL+bCsCeeYFxXyousFsXfQvA9pIqDA1RweeDqYfAr/+23Hya72O6OtbDvtAnqh3TJsCdjcFtFlvt32bx7ug80w6H98MHiQPgWaPC3+vTfP77ZhZ//CT9yueyesY9fJH9l4TPVR21K3FH6Oot7KstkbumD970Gk1eGmW9ccFbcU+S3PmHyOOneFFo9wFW/OsC/7pdM0NHIqx8mp37y/1zmH7468aUjH/y5px1FE/+MDBKItDXUVE5a33BUmhy55uxb1pQpfN232/cVXabH96RZNtWI1Hu+cZsHugJPK0oNOzZRJuZRiMBT8QKhjj4jI5TOsZTScmmDSxaEqiEb7SposhNB/0fmPcqbFlkfIadQZ3WLLjKjEp57Ys8f+y0rxYz45WpLNscOFYnRzUuVFWQ+IDWQtjB2xeCLwllvRoWfPICvde97V916N7Qttv7y8rR+8upaPT4fw8+QeuyFfz1yos5+u1RuJxVTFJPj4Tjn4AWB4esfuSzpUz/YjYnORZQ/vXvZCd6XHcpPBWj6STw8szVLNm8h7tOCHS6yvKVgoJsb6BXbK8yYyz1cp9xY+TxhCc1hTbuyP2qPtaOx+vFBXi8mmg9NA7+4DjL9Zt27acFsH7HAdoCvGcWmdyymUUzHiL2Twa0z4dyWNSv/PtcJqxawuXZi5i3q3Pabn0lcYuMpfduqXYZuMJHw43ROwKd/vSPrD2Qw+pc4/XTvimQBb5n3sClgpJ20BPfgg8eR/3yOv6U7PPBhuCxYoCN841RH88MnYpu29dP8lOu
eedsUdQc80HDG6GFw+tnwsRX/S9ve89oNx+cuJ0qKOEGF0Wt/R9KNTVWmz/1vaVuGkYIIaynagK8WuPCmAXJ4y5n7eK5dEzg/bv2l9MC2Fte6QLz9+bktjnP/7LM4yXHFUjOcx8/h/7mss/nxWF12fj9XeNiQPATUM2rXUUlAP1TVCkibMe3suoTLQdoijdEH/HxOs8zrM4NnyNUe0OLSFRQous15xYOdgcVV8x9Hv4V3iJlm0UzsruyrCeNqJC9d13IJNbbP3vIGJfF5Ik0F2ql3oFnOT/jn1mPhF9QTL2+DhQf8dzR5K76AoDmGGXHa7YHmiUqNLNWBcqUK34SXp/mw183JjQDkzZb+uQoNx89chkd3z4m7vcGn9tq2OPSPYE6k8lvhLaO6r81UB/g9dqj7X4ktS9xH/uP2PuIWsG5ovpjaMQzt+UJzh8t14fcnQJFK98BQJdalKPutq74XLcjtF34gdLY7YE7fnsNPNbP/7rxD3dS9kFgcuols/4b8b17dm2Hz/4Ka2dxV9bzjHX+RPl/4+u8Vvhl6OBYwR9fobnznblBr4Gty5n12t8Z/043vvz2W8tjLtywi59XbMT93LH8/Py1rPrrQSFJs9++xC/OFRcJq6exnju/9C+PWfq3yMew+RR3tS9xW/il3uB0hyBsqo1K3pCn9TYZFXub5oVPBD3rG+uxMPocmAVf3mWMIQ7s3VW17uw58wJ36T1/eyjifsumnQM/TIXnAn0estfFbvpnLbjXKXyYc3Potsf7M2yZMTtT3rqv0LvW4/4+dNyTCY99zz3PTidr7bf0W/McHRybKT8QKF+vyu9HmU8fVnfcjVTg2IP4NeIxfGbl5EszV8c8n29fzQ+bWzsT9+mhTa0OHPH3NAUi6pLt+8rZV+ah5SeXhW0b7IhQfAHw7QP43rkUgCxXaocoblCapNYPuzYwZHVggKtDnaHjzKhKRSPa52XLM6eQ9fmtxoiDS4yngmeyHuTNnDtD9vV6qtdK56DPjeLSHu6F/iGIrUSrH/GarWKafBR5VqwKw/bWbOcbqK2Ju/uxIS+HDBjI71kHR9hZiOTZ/vtXVXqf4/f/4F7wH+55I7VDqO53J6kI4IPond1yqTQEgM+HZ1+gfHnXT9OhfB9HOn8Oe++OvdUb1MuhA0Ut+l9HRtxPoekw2bp+o6I54Hjn7GrFkiq1M3EDu1seCsAH42YZrxv1DNtnSnZy50cUdVt7xxbavndKld+f9fZ53Lcl/G49mSLOU5pkqlLTlx+WbfHfxQLsWj0f7m5l+d61L1ehx2OkOCx6ulZoovawKvcsy22zVm5NzdACSVJrE3fDi99n89XrmDCom+X2pbm9OOPCayy3CVFb9XGsiL1TXGLV6oZuv9D1Ee0cJf7X7TyRu4kf4ZxXncCS4sV3ZpB/T5N0hxFRrU3cOJw0bxSplSnsczTkoOYNajAgIWqPWM37WunQStbGynqaO7t6JbtmpiCrqtqbuGOoqUdGIWojtbzmK+REQN1J3GavtvUuY9xuj7MeAB82Pi9dEeHRdefHL4RInjqXOXY0MToudDjCaDI0/qqpVTrOmLLqD4sqd/1CiKqoM4k722l81PLCTjBlF0V9xkfc96MjYk8Qu0Jb14gnouqjOQgh6rI6k7h7DTHG0u0zIPbYYOOGD+a7DtFbnNxzYu+kxBXNQ+6qNy0TQtRedWZ0wKzep0C7wTgL28beGWjUoS+sirz99EHt4OPw9cPLHsGJj/Ocn3CeK/adezSliQ/oKYSoA+rMHTcAcSTtvTktAHBnF1hu362NSk0VYdKGdbo5q3VLpjutxwoO9rPuGnX7i96jYx5DCFH31K3EHcHOC2f6l9873BhDoUvf4VyfEz5q2mFljzK67MGw9cNKH2Vc2d188ZeRfH7dCD6942we9p4ett997okcX3Yn48ru5r1ugcGAduj6IfuN9jzKBSOtOw8lqkzXmQcrIeoESdxAYdse/uVT+xvNBevnuHjw
pvB55Nq1bslKi4rJi48dQfe+h9GpqD6dmxkde57wnhC23z+9x/GL7sxZJxzLnacP8a+/zn15yH5tOnbn0M5NuNF9MV96+1Ttg5mOLZdBtoSoTeRWzLSvxWDyN80i2xX9WvbvS4dazjF4/qEdwtY9cUZfeCt03YwrD6VNozwa54eWX/cfNBzmP+B/fedxPSlums/Z3sOZ6evBaOd89upc6qtKg/eYnvaM51JX+HCii3xtadKhD1gPBy2EyEByx23KP/9tuOyHmBMO52W7aFlQL2Td2DLr7rHHHNwybF3vNoUhSfsO99l87+3J0L69WOQLlMEXN833L6/TzSkunc7Xvsh33q97R1uu/8Q3kGvGdOFHbw/L7UKIzBNX4lZKFSql3lJKLVZKLVJKDU11YDUup0HYpK3BRpY9zNnlky23rc/pVOXTPu8dy1nuW+jXrhHRRgju0bIhDXIDD0g7dX7I9lW6Ja96juAG9yUh67VW5Oe4eNkbeXhLIURmifeOeyrwX611N+AQYFHqQrKnqX86mfEnnGm9bWLVy6A/umo4j07qi1KKfdlNI+939XDycwOTl1buddm6sB63eC7kTe+osPd2bdFAmhYKUYvETNxKqQJgBPAvAK11udZ6Z6oDs5s+bQuZOKid5bb+7RtX+bg9WjXkuEOMys6dKrwJ4ksXDPJfGNwq8uwoFw0PL2MHGNqpCVlOByNGHFHlGONxUtkU/u0ZGXH72eWTeV+NSmkMQtQV8dxxdwBKgOeVUvOUUs8qpfJjvakuKaiXFXunOFiNXTLioCKO79MagLeKrmSzLmRo6WNcq27kDx24YJw7tJinzurPyrvHWR5bNWhBcen0pMRpZWDXdtzouTTi9uW+1jQvyKvSsW90x54+Soi6JJ7E7QL6Af/UWvcF9gFhhb1KqUuUUnOUUnNKSkoqbxZxqGjRcrv7XMvtpc58Bpc9yUaaME91Z1jZ4/5tDofimINb4HBYl5OfOaQ9Nxzdlemew5MfOHDqgDb8OuUolroOstw+uHf3mBW/kcz1WR8z2FfeQ6p0bCEyUTyJez2wXms9y3z9FkYiD6G1nqa1HqC1HlBUVJTMGOuMvGwnAHt1Pcvth3dtBkBxkzx/DtyiC+M6dpbTwRWHd+bJnSicAAAaOUlEQVRmT9XvXj/yDvIvW7VSaZib5X9quNl9Yci2m4/tVeXz9mgZe8KLPfnFVT5+sDc9I5JyHCFSKWbi1lpvAtYppSr6Zx8B/J7SqGqhrboht7rPr9YxTunfhhlXHspn1wXKku9q+TgXl18X8T17s5uFvL7/5MDgWEeV3ZfQ+RvnBYqEbvNU+izmjCj/aHgj0z2jeaNSJWmO0xn3eVYWjQl5fePRsXuQNuw2kvU6cuVuvE69y3ryWCHiMdfXpUbOE2+rkj8DryqlfgX6AHenLqTMEVzGHMk/Bn3L/3X9lHsP/oAfGoX3pLSiVOQBX3u3KSTL6eDmsd1RCu6/cBwP3nZzyD6HlwW60v/W/PiQbacNDLQV36QbxRVPhdmNY4+/cumJY3ijxV/4v3GhkzMX5GUxu9GxUd/7gXcwAFtbBu56zyi/Oa7hbwvzsvm4XXomf77efSlrO52RlnMLe8kaUTPz2MbVc1JrPR8YkOJYbOmwsqk0YL/VQICMKPsHClgW5f3XjEuk7NUoZhjSMfYkpacNbOtPwrlZoXezJTqodUqUcuWCwiZg3RHTUusBE+D9/wNgVNdm/LayPT0doZO+9mlbyHtXmkPnfhn6/nV5PSkunc7qXOskt1/nGgtB8xn+6DuYeEcuv+iss+Ae67b2AN94ezPS+Wtcx0qE1go1+FJYkbrKX5EZeo+xnjU+2aTnZAzrdRGLdHvLbR9eM5q3rojcBC5RFempXeOqtb6ocNmIjnHt98FVI3nWMzbu457cvw0rfEZv0LOHtGd8+T1BvT3Dk+t2c+CsikG5hnWOfEEq0Q1p1Si8bP/RSX3jjk/lRC8L395sCN8Mejru48V9
XpnISNQwSdwxXH/UQZzUt7Xltq4tGnBI2/gqB2vS2N6BQbB0jLvVd7zD4zrmkNLHAMh2GllKKXj5wkERh7cF8GI8Cbx0udGG/Pg+rZl765iw/WZ6e3Bk2QP+WYrAGNf8pLIpHHdIK2JMKB63lgW5tOvUPeL268ovA2By/t+4tTD+8v+T+7WpdmypsMhn3e8g2e5yW3dME6kjiTuGK0d34eHTqzc6X7yU/9/qZarmDQMddZrkx9djcq2vKGLl6ZfePmyi8t2yYniXopBkW1lFUs91BYpymtQP70T0i+7IThqwJdd4stlfrwXrdHN+1rGbASZCA7764ePHAHi14h2fUbZ+7w1Xcdc1l8V9XGcCFa816azym6Jun++r+lANweb5OiflONU12X1RukOoMZK4bWRFvnGB2F0/OX9QAGcMti7mqWwveXgjfB0ucN8Y8X3/yTEqPz31w+864y1BeMRzCs+cM4DvmpzOqWV/ZXPRMO4+sRcTB0YuhknEX8128RsL+uHIyadPaXKLSw4+unqthVLli+tiFOOp5Pz5n3P66VxbfnnsHZMgWnt9d8PUPmEMVy+k9PiJkMRtIz8VjGVw6eNsa5S8+SydFh1yzi6fzIXlf8ER52//zcuGsvQu67Lw1e1OpLh0OnkNY1eoWlnnK+LEgZ04skdztHLwkzaa/p0xuB33mk0X43kCqdjjY+/AsG0veY+mU+nLDBs1lnaN8+hfHN5scIGrJ9PO7h+6clT0O9YKefWs292nW2F+5CESAJbmh3XHqJJje7fiC19yjhXJVw1PoLh0Oh/6hkTc56Lh8d3w/OrrwAad+Pf1u9tPTPg9qSKJ20aUQ7GZqo97UqGiTHhPhI483/l684WvPw1ys7jh6MD0aSMOCu04NdvXlcvc1zCwuHHYOOXKvCDcd3IvXrlwMO2ahFeoxnOfrJSmWQMjwbRoaLQqKagXvXinVIcPMVDxmf9lUdk6aVBburVqRIuCXJwOxUPnhHey6XPjfzmqZ4vQlaPCW6hs1/WZ5YtvZqLXIvRSfcxzAqWO6iX7tb44OrnlNIy6ueO4yO3/E+FwKLq3b8kuXb1K9eiif5u0MxtHTv2o+wQrC/oObdOxO3jZjSTuOq6iBUuOy0ETs4ONWzt523sYp5Xfzie+QdHeTl62i8O6VL3jy3xfJ64cbXRauHJ0F6ZO7MPRPZuH7OPOC319cvkUf0UiwF6dy46mxt1y8HgvC33FzPT24J6TevPhVYFK2Pwci1awMVqkAGxVjXixzxtcVh5oq7uzoljLopL2Js/FDCp9Imz94ROvZV5R9Db9/yB6u/BNjfpH3FZcOt0Yl8YV/QJY3CL097bSF7hwPeGJ3WY/2LRzh3BI2bMJvScR2S4HvVoX0K+d9Y1N2eU/xV30U7n57AfeyHfxlc30hfcYrugl7ElxUU0wSdw2ctXoLvRv34ijK9/5JaoqzTCC8s6b3pH8xf0ny0NVt+I02ISyu3gk/1r/3Xy2y8HxfVqHtVTx5RTSsfQVfyuJq8d09VckAhxc9hyOfCMJjesVaFEzofxuJrlvDTtvVpQK1TDXL+PXnkYZ/6oGA7jq+EN58cqxLHcZFXIlox6I+NaWBblsIbyTU7t2xTHL/w8fHr182uOM7+52oSvyBBoqN3CxGl72CJe4jTvwNb5m9GljtJZ6zHNCxLFzQg8WVzhR/RQ0Jk3lXrAOpXj/z4dx2tkRytLrNY4Zwmllt5nHCv0OvJgdf+epGd5hYeu+8xnj+Ndks1BJ3DbSrkkeb18+LGmjDUby1Fn9LdtHW6XkSyq1CV/sNO6OdVb8j6WVDSt9lKPK7mOh7ogzJ3YC0oAPh/9uun3j0GKG6486iJFmMc+wzoE/+AdO6V2tsdIBqN+MsiyzQ5Ny4HQoercppE0L4+LaqWXkoq1zhhZbrm/YMLwJaaLFDJEGE7vffToTegdazriiJJMm9XPwaCMF/P2s0TwyKVA/4DIT2wGdzYveo2PG0yDHxaiu
1RujaJkv0Oz2U691fz9XXgFl169ip6vSU5724c5rZvmeCn871fh8mtDv+kfXxT/wmtXUhg+dalSYRviVpIQk7jromINb+McAD/4Kby00KgO/8BlJffW947l5XGi75/uzr+DYsrvw5ocWXyTirGMO5eA+Q7hpbDeeOy+8MrGydo3zaNYgxx9p5bbpV47uEpTIAn89pw5o6x8St1q0zzxv4M8ld+LzMGYKjlahF4bgYQTOHNKO8b2tmx9WtsURPemEjHU+eV3E/RZldWfqxL4s/tsxcZ3XY7a1H1jchEZ5RtGKy5l4BnI4FC+cH71YLRHdWoQWXQU/6eXUb0zhrStCtmvlxFu/BQNLn6BEW5ftq6Clncq4eG7uPJHc+vEP/XBk9/Df06CDjb8RNbzmhlyQxF0bmc9sic56s7NhVzqVvswXvsjlp+UqhwW6Y9SON7H8aVRnHj69D5eO7ESbRrHvNHOznMy+ZYy/bDLL6eDxM6x7VKbicXV7Q+MPc3GDoBn76jeDw64NOqHx7wEd+Jk3zM3iiTOq2Nqi0ue40RM0JV1uQ2Y2O83ybVeM6ozTofw/q1hFW5O8d/K0Zzy46uFt2Ib5vo48kH1FpDDoW/pU3B8hUZ1aBlp6ZMWYtBvgKc8EAL7w9sXnMp7CSrCeAvDy8qtDfhLOiS/xYpNraTLpnwnFmG11UcvOgym7oH8cRUpJIom7FtI5DbjffTrnc0dC72uSn40XJ+cObc+n11oPb1oxm31Rg+hNzQBucV7Ht95e+PKqP2ofBCoetdZMCOodGuxAYXxNwg4qfTGkZUE0I0eM5qbunzFuYuxOOfGWny9ucwqbow3JWynfXnVEaGekRq278pwn/K66YmhgqwPtILwCtn6H/tzjOROn00HjhvmcUH4XXYdO8L+r4uNM84xnp87n+GHRm6pO84yPuj0aV4fg8uPYV+BSjO/gQt0BX4x6nY99g0Ne9zyoC+f+eYq/SCiW33OMJyudrG681SSJu5Z60ns8q7FObtYUR/dswRNn9OO2CT04qLl1K4tzhxWz+t7x1LdqmVGJu+0wznHfRE52cua79P/JRPnj8WWHT/9mpZwsPHF+/XOznNxz+iCaWvT69J+3nnG3+LVrGEt8bbim/E9RjznxmMN5d3RgFK5YaeraMV3YnN+VlYf8BTDKz5vkhf8OKv9kXs037gJXD/07K2kbtv9TZ/Xno6uGk+1y0CA3i9X3jufyUYGL3+AOTXjxgkG83ugS+pQ9w5TjeoYdI9jKvpEH+YrHby0itbYJ/50P7hCoX3AoRedm9f1NSyv7/LqR1ao/9aqKJxiff92yCW/x00HJaVKZKEncdZw717gbnufqg1KK8b1bxn0XEstjk/oala15yals/aGZUfvvahp5EK10DfiUVb8xvUuf4eOiizi6/H7e9R1mud9O8643N8vJpSODng4qx13ptVKK5jfMpuOJfwWMMuXgoQ0i2dF8KMWl0znQ+xyec5wc2HDuB4DRNLJHq/AyYbfDaFPvc+Yw8qAi3rviUH6YPBqIPnnHvUHjvf/Y8pyY8QXz1CtiR37gd+vT8f0ym+Rnk5/jIi/bxexbxlhWEnZuVvXKdAhcNrY3CLTh7zLgSAaecXu1jltVkrhrsXi+9p78FhxaOpXn6yW/fC4/x0X/9omN+R3NaedfwzdnLKe4TeQKx6YxegtWeO3iIbgxLigVd1PV0aqwHvecMZwnzjJaQ4zvZV0p6VZxXsTieiIP36lymfY9J/XiyTP70b1lQyYeGlTc0iH64GJzWp3JQ+5T+LXV6QA0yM2idaFRjuyL89410XqQ/iMmBN4L9Ct7isc9x/tfW5wAgE4xkvKuiTMA8GYbF6g1zuKE4gJwm/VFu5umtodovCRx10L1zIqpq8fENxvHBorwJSF5pVpetsvf7C8Sqx6cVoZ2asLKekb722WDkzMvyPjeLSnMy2bOrWN4JMLAZPuxjq8qDwrxvKd+jotx5kWkqEFu3Md2Ztfj
Me9JOLPDL4QqqKPLVh2laCqBxP2hdxAupyNQhqyMp5NNcUxWEktBN6NFTtP2PTit7DbWDb0z6v4rdehFd5GvLc81uRaAE/sloZVSEkjiroWynA5W3zuei4bHNy43JKX/hG3s1bnc7Z4Uc7+Kyk5PdvSu4YlqWj/Hsr3vre7z+b96kR6tjYRVps1y6zh+IfM6hXdG0Un6TV48vCOXjezEecOKo58janJ2MNsVef6VSeW3+JdvzzEHMguqv/jyLyPp3y7+CtxYihrk8MrfruO8keFD+44tu4dXPMbww5vbHMP/Hfwdn3mN1lUPe05lt8OII6HOWylkjyhE2rQ1m+OdMzS+UQTt5A+X9TjYB5c9xzRv9GnSIJBzaqqhwCveI9mVY12E4jCzkD8pxhHTBUeGJ8Vk9Wytl+1k8thuYd3DIf6Lw4IWJ9L1mhkRt18SdGPxedhIhoqORfXxdh3Pbp3H/5qcFNc5Ae50/CniWOTZLodlEc7T159H63YV8SjuO6U3TYMqOkd3i97OvqbFNXWZqL0K8oyWBJmo6dXfsWfHBotGbvZ01RFdOLW/9cVGVU7ccbC6q48uOXfjWqm4LixHDexJQf38iNubFeSx//S38KydTWGedcujgmbt6F32LDe1thjYK8LHWVZwKGM39oo4RZ6Vdk3y+MN/kQr9cOcOK+ZQiyePdJLELTJWdoPGZDeofhloEodfieq6I8MnhthBQxqx2x+EP3EHJaUSXUCkkv336x3HsQcCd7U6Wl1Fkprc5OdkxTVXactKnasmlt/K69l3BYWjyOt+JHQ/0r+u4ldREemY7s144fyBDO8Sf3f6F84fyPfLt1Ly67HoFr2p6r1yxcU0P9tZrQ5nqSBFJaLWuXV8dz66KvaUbBV/ir4UZ+5bmz3BFHf0pnEOHfmOe1BZ+AiDFY649nn/8oueI9lTlPpWDw1PesS//Loj8tOao9Joff+rNLKesmi3V8/sQFTxr1KKUV2bWY4rH0mzhrmc1K8NRee9QrNjIk8CEh6v9e9AJWnCiWSSO25R6yRSKVsTbrtkEvvKrLuoV6i4u7Nqahet+CQvO/AnfLvnfF6rgdZB6qDAoFNv1juVKyPtFyPhKYv7xn4dmsMS6Na26mPhVFWnxjmwGoqbm01YbdJL0or9LiVC1LQU/33muJw0jjH35+vZRuXbnRhjkuxpdLB/2y+3HRX3uaLfmSb/cf/lCwZH3OaoNMXSf/40jEnlt7DaZyZli+IH54DzYcSNOEbU3IBNFZqYrSVbNa5Ua2KzYhKQxC3qsD1Oo4mXx5X+qccK+xxPcel0PD1Oorh0OvWbtmGBrxiAgjgnfIYYiTsFCShau/nK5cJ92zVipq9nYMwUqztyVzaMvsUYuCmGXLNyNifhStoIsszvgTlzkMo2KlaTNWRDMklRiaiz3m91FR/91pqjWx6a7lA4uX9rLjx6IF6tOW1AW/q0LeTP7R9mw/JfeSeB40RLYgcaxzflWrIoi0lN/33pUHJfVuCDuCc9jaDXiTew/qXf6X1S/OXYUR12LbhyoZ/Ri/igC6bx+0cP0/2w0PFTLii/HjcuXk7OWatEEreos249YSAvNi9i5EHpa6NbevQD7P7qrzRq0gzlUDhQDO1kDFj10LmHs788diVrhZvHdaOnxbgjfq563Ok+mz6O5SQ2MVnyDOrQmMfyxtJ97zLcDavXd8DVoCltrng/SZFh3HEPDwwaVa+gKT0mhfeqveHPV+OqyVkTLEjiFnVWQV4WVx0R37AAqdJy6EQYOtFyW7bLQXaMeSODXTIi+pC2GnjOOxa8pC1xA2zocCrFPw3hq0Y1XwGZDN1bJrenbVVIGbcQdUTSyoJjWNH+9KjbpxzXk7cuG0qHppE754jo5I5biDqiRw3dKXY672ng6Yjbc7OcDChOQsepOkwStxB1hFKKk/u1Ycnm3ak+UWqPLyRxC5Hpbnefy6GOhcTT2vuh0w5JeTwi9SRxC5HhXvQezYveo1md
7kCi8GqF/Ud8zxySuIUQKTW+7O9s1QXMSncgtYgkbiEy3NfXj+KPXQfSHUZEv+kO6Q6h1pHELUSGK26aT7E0ratTJHELIVLqiG7N/L1BRXJI4hZCpNS/zhuY7hBqHek5KYQQGUYStxBCZJi4ikqUUquBPYAX8Gitw6eXFkLUGaXHTcNRvyn2G6m6bkikjPtwrfXWlEUihMgYuf2iDyQlUkuKSoQQIsPEm7g18KlSaq5S6hKrHZRSlyil5iil5pSUlCQvQiGEECHiTdyHaa37AWOBK5RSIyrvoLWeprUeoLUeUFRUlNQghRBCBMSVuLXWG8x/twD/AQalMighhBCRxUzcSql8pVSDimXgKGBhqgMTQghhLZ5WJc2B/yhjcHQXMF1r/d+URiWEECKimIlba70SkNHXhRDCJqQ5oBBCZBiltU7+QZUqAdZU8e1NgUzp6COxpkYmxQqZFa/EmhrJiLW91jquJnkpSdzVoZSakyld6iXW1MikWCGz4pVYU6OmY5WiEiGEyDCSuIUQIsPYMXFPS3cACZBYUyOTYoXMildiTY0ajdV2ZdxCCCGis+MdtxBCiChsk7iVUscopZYopZYrpSanMY7nlFJblFILg9Y1Vkp9ppRaZv7byFyvlFKPmjH/qpTqF/Sec839lymlzk1BnG2VUl8ppX5XSv2mlLrarrGa58hVSs1WSv1ixnuHub6DUmqWGdcbSqlsc32O+Xq5ub046Fg3meuXKKWOTkW85nmcSql5SqkP7ByrUmq1UmqBUmq+UmqOuc6u34NCpdRbSqnFSqlFSqmhNo61q/kzrfhvt1LqGlvEq7VO+3+AE1gBdASygV+AHmmKZQTQD1gYtO5+YLK5PBm4z1weB3wMKGAIMMtc3xhYaf7byFxulOQ4WwL9zOUGwFKghx1jNc+jgPrmchYwy4zj38BEc/1TwOXm8p+Ap8zlicAb5nIP8/uRA3QwvzfOFH0XrgOmAx+Yr20ZK7AaaFppnV2/By8CF5nL2UChXWOtFLcT2AS0t0O8KfugCf5QhgKfBL2+CbgpjfEUE5q4lwAtzeWWwBJz+WlgUuX9gEnA00HrQ/ZLUczvAUdmSKx5wM/AYIxOC67K3wPgE2Couewy91OVvxvB+yU5xjbAF8Bo4APz3HaNdTXhidt23wOgAFiFWbdm51gtYj8K+MEu8dqlqKQ1sC7o9XpznV0011pvNJc3YQy8BZHjrtHPYz6a98W4i7VtrGbRw3xgC/AZxh3oTq21x+Lc/rjM7buAJjUY7z+AGwGf+bqJjWO1mujEjt+DDkAJ8LxZBPWsMkYctWOslU0EXjOX0x6vXRJ3xtDGJdM2TXGUUvWBt4FrtNa7g7fZLVattVdr3QfjbnYQ0C3NIVlSSk0Atmit56Y7ljhFnejERt8DF0Yx5D+11n2BfRhFDX42itXPrMs4Dniz8rZ0xWuXxL0BaBv0uo25zi42K6VaApj/bjHXR4q7Rj6PUioLI2m/qrV+x86xBtNa7wS+wihuKFRKVYxSGXxuf1zm9gJgWw3FeyhwnFJqNfA6RnHJVJvGirae6MSO34P1wHqt9Szz9VsYidyOsQYbC/ystd5svk57vHZJ3D8BXcxa+2yMx5IZaY4p2Aygoib4XIzy5Ir155i1yUOAXeYj1CfAUUqpRmaN81HmuqRRSingX8AirfXDdo7VjLdIKVVoLtfDKI9fhJHAT4kQb8XnOAX40ry7mQFMNFtydAC6ALOTGavW+iatdRutdTHGd/FLrfWZdoxVRZ7oxHbfA631JmCdUqqrueoI4Hc7xlrJJALFJBVxpTfeVBboJ1j4Pw6jZcQK4JY0xvEasBFwY9whXIhRXvkFsAz4HGhs7quAJ8yYFwADgo5zAbDc/O/8FMR5GMYj2q/AfPO/cXaM1TxHb2CeGe9C4K/m+o4YyWw5xqNojrk+13y93NzeMehYt5ifYwkwNsXfh1EEWpXYLlYzpl/M/36r+Nux8fegDzDH/B68i9HK
wpaxmufJx3h6Kghal/Z4peekEEJkGLsUlQghhIiTJG4hhMgwkriFECLDSOIWQogMI4lbCCEyjCRuIYTIMJK4hRAiw0jiFkKIDPP/Puu1QrMha2sAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#scrap\n",
    "# Overlay both recorded loss histories on one explicit Axes. Plot order is kept\n",
    "# (self first, ref second) so the default colours match the previous figure;\n",
    "# labels + legend make the two curves distinguishable when skimming.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(loss_change_self, label='self')\n",
    "ax.plot(loss_change_ref, label='ref')\n",
    "ax.set(title='loss_change: self vs. ref', xlabel='record index', ylabel='loss')\n",
    "# Trailing ';' keeps the Legend repr out of the cell output.\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "#scrap\n",
    "# Shallow-copy the current validation-loss history into its own list under the\n",
    "# *_ref name, so it no longer aliases val_loss_change.\n",
    "val_loss_change_ref = list(val_loss_change)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1255ba780>]"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3XmUXOV55/HvU/vS+95qLS0hgUACYSxjwIEhwcRAGDhJ7BgnOV6yEM94YjvLzLEnJ3biOfNHzvjEsZOMGcZrMo6T2CYxcQzBxk6CF7DFZlogoQ0trZa61S31vtTyzB+3JEqtFuqGbtX2+5xTp27d+3bV01fqX99+73vva+6OiIhUl1CpCxARkeWncBcRqUIKdxGRKqRwFxGpQgp3EZEqpHAXEalCCncRkSqkcBcRqUIKdxGRKhQp1Qe3tbV5b29vqT5eRKQiPfnkkyfcvf1C7UoW7r29vezYsaNUHy8iUpHM7OBi2qlbRkSkCincRUSqkMJdRKQKKdxFRKqQwl1EpAop3EVEqpDCXUSkClVeuB9/Hh79HzA5XOpKRETK1qLC3cw+YGZ9ZrbTzD64wHYzs0+Z2V4z+4mZXbP8pRYM74HHPg7jR1fsI0REKt0Fw93MtgK/CVwLbAPuNLON85rdDmwqPO4FPr3Mdb4sXh88z06s2EeIiFS6xRy5Xw484e5T7p4F/g34hXlt7gb+ygOPA01m1r3MtQLQPxXcMWFy7ORKvL2ISFVYTLj3ATeaWauZpYA7gDXz2vQAh4teHymsO4uZ3WtmO8xsx9DQ0KsqeN+YAXDqlPrcRUTO54Lh7u4vAH8CPAI8DDwD5F7Nh7n7/e6+3d23t7df8KZmC4qlGwHITI+9qq8XEakFizqh6u6fdffXu/tNwEngxXlN+jn7aH51Yd2yixfCPTulcBcROZ/FjpbpKDyvJehv/5t5TR4E3lkYNXMdMOruA8taaUG6Lgj3/Mz4Sry9iEhVWOz93L9mZq1ABnifu58ys/cCuPt9wDcJ+uL3AlPAe1aiWIB0Ms6EJ8jP6MhdROR8FhXu7n7jAuvuK1p24H3LWNd51cUjTJKAOQ2FFBE5n4q7QjUdCzPhSUJz6pYRETmfigv3SDjElCUJZ3TkLiJyPhUX7gDToTThzGSpyxARKVsVGe5zoRTRrMJdROR8KjPcwyliualSlyEiUrYqMtwzkToSeR25i4icT0WGey6aJuHT4F7qUkREylJFhns+Vk+ULGRnS12KiEhZqshw91hdsDCrse4iIgupyHA/M2GHLmQSEVlQRYZ7KN4AQEZ3hhQRWVBFhns4GRy5z0ycKnElIiLlqTLDPRUcuc9Mjpa4EhGR8lSR4R5PBvd0n5tSuIuILKQiw/3MVHvqcxcRWVBFhnuivgmA7LSO3EVEFlKR4Z5KN5B3Iz+toZAiIgupyHCvS0SZIIHrIiYRkQVVZLin4xEmSeoKVRGR86jIcK+LR5jwJKZ5VEVEFlSR4R4OmabaExF5BRUZ7gAzoRQRzcYkIrKgig332VCaaFZH7iIiC1lUuJvZ75jZTjPrM7Mvm1li3vZ3m9mQmT1TePzGypT7skxEU+2JiJzPBcPdzHqA9wPb3X0rEAbuWaDp37n71YXHZ5a5znMEU+0p3EVEFrLYbpkIkDSzCJACjq5cSYuTi6RJ+pSm2hMRWcAFw93d+4GPA4eAAWDU3R9ZoOkvmtlPzOyrZrZmmes8Rz5WT5g8ZKZX+qNERCrOYrplmoG7gfXAKiBtZr86r9k/Ab3ufhXwLeCL53mve81sh5ntGBoaek2Fe7ww1Z7GuouInGMx3TJvBg64+5C7Z4AHgBuKG7j7sLufnq36M8DrF3ojd7/f3be7+/b29vbXUjd2eqo9XaUqInKOxYT7IeA6M0uZmQG3AC8UNzCz7qKXd83fvhIsEYR7Rvd0FxE5R+RCDdz9CTP7KvAUkAWeBu43s48B
O9z9QeD9ZnZXYfsI8O6VKzkQSRRmY5oYJbrSHyYiUmEuGO4A7v5R4KPzVn+kaPuHgQ8vY10XFE6enmrvFPUX84NFRCpAxV6hGi/MxjQ3qUmyRUTmq9hwj9a3AZCbGC5xJSIi5adiwz1R30LWQ+SnTpS6FBGRslOx4V6XiHGSOpgcKXUpIiJlp2LDvTEVZcQbsGkduYuIzFex4d6cinGSesLTJ0tdiohI2anYcI+GQ4yHGojPqVtGRGS+ig13gKlIM4mMhkKKiMxX0eE+F2sinRuDfL7UpYiIlJWKDvdsooUQeZjR0buISLGKDndPtQYLU7qQSUSkWEWHeygdXKXqkxoOKSJSrKLD/fQtCGbGBktciYhIeanocI83dgAwfVLhLiJSrKLDPd0UhPvM2Gubsk9EpNpUdLg3NTYy5XGy4wp3EZFiFR3uLekYI9ST1wlVEZGzVH64ez2had2CQESkWEWHe108winqicwo3EVEilV0uJsZk5Em4nO6M6SISLGKDneA6WgTqexoqcsQESkrFR/umXgLSZ+C7GypSxERKRsVH+75REuwoPvLiIicsahwN7PfMbOdZtZnZl82s8S87XEz+zsz22tmT5hZ70oUu6C0bh4mIjLfBcPdzHqA9wPb3X0rEAbumdfs14GT7r4R+ATwJ8td6PmE69oByOlCJhGRMxbbLRMBkmYWAVLA0Xnb7wa+WFj+KnCLmdnylPjK4g1BuE+e0v1lREROu2C4u3s/8HHgEDAAjLr7I/Oa9QCHC+2zwCjQurylLix5+v4yowp3EZHTFtMt00xwZL4eWAWkzexXX82Hmdm9ZrbDzHYMDS1PN0pdYzt5N2Z1218RkTMW0y3zZuCAuw+5ewZ4ALhhXpt+YA1AoeumETjnDKe73+/u2919e3t7+2urvKClIckI9fi4wl1E5LTFhPsh4DozSxX60W8BXpjX5kHgXYXltwLfcXdfvjLPryUVY8ibsMnjF+PjREQqwmL63J8gOEn6FPBc4WvuN7OPmdldhWafBVrNbC/wu8CHVqjeczSnYxz3ZmJTOnIXETktsphG7v5R4KPzVn+kaPsM8LZlrGvRouEQI+EWkrPPleLjRUTKUsVfoQowHWsnnRmBfK7UpYiIlIWqCPe5VAdh8qBJO0REgCoJd9KdwfPEsdLWISJSJqoi3ENN3QD42ECJKxERKQ9VEe6J5h4Apob7S1yJiEh5qIpwr2tbBcD0iMJdRASqJNw7mho44Q3MnZp/PzMRkdpUHeFeH2fIm/BxnVAVEYGqCfcEx72ZyJRuQSAiAlUS7slYmJPhZpIzugWBiAhUSbgDTMU6dJWqiEhB1YS7rlIVEXlZ1YS76ypVEZEzqibcw7pKVUTkjKoJ97iuUhUROaNqwr1eV6mKiJxRNeHe0dTAsNfrKlUREaop3OvjDHozjKvPXUSkisI9wTFvJjKp0TIiIlUT7slYmBPhdlIzCncRkaoJd4CJeCd12VOQmSl1KSIiJVVV4T6X6goWxjRiRkRqW1WFuzesDhYU7iJS46oq3GMtQbhnTx4pcSUiIqV1wXA3s8vM7Jmix5iZfXBem5vNbLSozUdWruTzq+tYB8Dk0MFSfLyISNmIXKiBu+8GrgYwszDQD/zDAk0fc/c7l7e8peloaWbE68iOHC5lGSIiJbfUbplbgH3uXpaHxquakgx4K/lT6pYRkdq21HC/B/jyebZdb2bPmtlDZrZloQZmdq+Z7TCzHUNDQ0v86Avrbkxw1FuJTOgWBCJS2xYd7mYWA+4CvrLA5qeAde6+Dfhz4B8Xeg93v9/dt7v79vb29ldT7yuqT0QZDrfpQiYRqXlLOXK/HXjK3c+Zhdrdx9x9orD8TSBqZm3LVOOSTMW7SOXGYW6yFB8vIlIWlhLu7+A8XTJm1mVmVli+tvC+w6+9vKXL1AWTdjCqse4iUrsWFe5mlgZuBR4oWvdeM3tv4eVbgT4zexb4FHCPu/tyF7sY1nj6QiadVBWR2nXB
oZAA7j4JtM5bd1/R8l8Af7G8pb068ZY1sB8yJ48QLXUxIiIlUlVXqALUFy5kmhgsy9GaIiIXxaKO3CtJV0sDQ95IfuRQqUsRESmZqjty725KctRbcZ1QFZEaVn3h3phgwFuJTmq6PRGpXVUX7olomJFIO3UzA1CaATsiIiVXdeEOMJ7sIZ6fhskTpS5FRKQkqjLcZ+uDETOcPFDaQkRESqQqw91a1gPgIwp3EalNVRnu9V0bybsxPbiv1KWIiJREVYb7mo5mjtHM9PE9pS5FRKQkqjLc17WmOOSdoG4ZEalRVRnuq5tTHPIOEuO6SlVEalNVhnsiGuZkvId0Zlj3dReRmlSV4Q7FwyF1AzERqT1VG+7h1mA4pMa6i0gtqtpwT3VtAmBmcG+JKxERufiqNty7OroY8xSTxxTuIlJ7qjbc17XVcdA7yJ3YX+pSREQuuuoN99YUB72TmIZDikgNqtpwT8cjDEVWUTdzFPK5UpcjInJRVW24A8zUrSXiWTilo3cRqS1VHe7Z1o3BwgndY0ZEaktVh3u863IAMsdfKHElIiIX1wXD3cwuM7Nnih5jZvbBeW3MzD5lZnvN7Cdmds3Klbx4Xd09DHkDE0d2lroUEZGLKnKhBu6+G7gawMzCQD/wD/Oa3Q5sKjzeCHy68FxSl3bWsc97uHRod6lLERG5qJbaLXMLsM/d59+w5W7grzzwONBkZt3LUuFrsL4tzV7vITW6T5Nli0hNWWq43wN8eYH1PcDhotdHCutKKh4Jcyq9gURuHCaOl7ocEZGLZtHhbmYx4C7gK6/2w8zsXjPbYWY7hoaGXu3bLEmuNbjHDOqaEZEaspQj99uBp9x9oUPgfmBN0evVhXVncff73X27u29vb29fWqWvUrJ7CwCZ47suyueJiJSDpYT7O1i4SwbgQeCdhVEz1wGj7j7wmqtbBj1r1jPmScYO95W6FBGRi+aCo2UAzCwN3Ar8VtG69wK4+33AN4E7gL3AFPCeZa/0Vbq0q5593kPPoI7cRaR2LCrc3X0SaJ237r6iZQfet7ylLY/etjRfp4dNozpyF5HaUdVXqAJEwyFOpjZQlxmG6ZOlLkdE5KKo+nAHyLVoxIyI1JaaCPfY6m0AzB1+ssSViIhcHDUR7qvWXkK/tzK19welLkVE5KKoiXDfsqqRJ/OXEj36I92GQERqQk2E++rmJHviW0jPDsLokVKXIyKy4moi3M0M1gQ3qfRDj5e4GhGRlVcT4Q6w9vI3MOlxRl98rNSliIisuJoJ9+s2dvJ0fiP5gzpyF5HqVzPhvro5yYuxLTSN74HZ8VKXIyKyomom3M2MTM+1hMjjh39c6nJERFZUzYQ7QMflP0XWQ4z0favUpYiIrKiaCvftl63j3/LbSLzwFcjnSl2OiMiKqalwX9OS4vHG20nPDuF7Hy11OSIiK6amwh3gspvexrDXM/z9z5e6FBGRFVNz4X7n69bxcOgmmg4+ApPDpS5HRGRF1Fy4J6JhMle+gwhZTj7x/0pdjojIiqi5cAd4yy1v5qn8JmLf/1MYPWcebxGRileT4d7dmOSxLX+MZ2cY+vw7IDtX6pJERJZVTYY7wPveejt/3fH7tJ96loOfexczh5/V7YBFpGrUbLhHwiHe/Zu/y1dTb2d1/0MkPnsT/f9zKy988y9xHcmLSIUzL9HR6vbt233Hjh0l+exis9kcP975IhPPfJ3el/6ezb6P46FOjl31n9l8+28RjydLXaKIyBlm9qS7b79gu1oP92JzmRzff+hv6Hz6z7jC9zJAG893/keaXv+LXHXNDUQj4VKXKCI1blnD3cyagM8AWwEHfs3df1i0/Wbg68CBwqoH3P1jr/Se5Rjup+VyeXY+9gCJJ/6cjVPPEjJnN+t4svuXab/hl/kPl68mFqnZHi0RKaHlDvcvAo+5+2fMLAak3P1U0fabgd939zsXW2A5h3uxmZMD7P/e39Pc93m6Zw9w3Jv4+9AdjF7xq2y//BKu39BK
Yypa6jJFpEYsW7ibWSPwDLDBz9O4msP9DHeye77N2Hf+jJZj32PCk3wi+wv8df4tbN/Qye1XdvOWLZ101CdKXamIVLHlDPergfuB54FtwJPAB9x9sqjNzcDXgCPAUYKg3/lK71tx4V7sWB/5b/8Rob3fYji5ni/kb+dzo69nypK8obeFGze2sW1NE5u762mviwdzuIqILIPlDPftwOPAm9z9CTP7JDDm7n9Y1KYByLv7hJndAXzS3Tct8F73AvcCrF279vUHDx5c0jdVVtzhxYfh0Y/B4PPkIyl2tt7K/528iQdPdAFBoMcjITZ3N/D27Wu4++pVpOOR0tYtIhVtOcO9C3jc3XsLr28EPuTuP/cKX/MSsN3dT5yvTUUfuRdzhyM74MkvwM4HIDNFrvVS+lfdxnPJN9CX6eG7BybZdWycZDTMlT2NXLW6kTu3rWLb6kYd1YvIkiz3CdXHgN9w991m9kdA2t3/a9H2LuC4u7uZXQt8FVh3vj56qKJwLzYzCn1fg74H4KXvEQwsMrzzCg73/hJfmrmBHw9k6Ds6xlw2zxXdDbz7Tb3cffUq4hpmKSKLsNzhfjXBUMgYsB94D/B2AHe/z8z+C/CfgCwwDfyuu//gld6zKsO92PhxOPJjOL4z6L45+hREEtB1JXMd2/ge2/jE3m6eG5yjsyHOXdtWccWqBratbmJ9W1pH9CKyIF3EVG76nwyO6I8+DUefgcwkHk1xoutGvjZ5FV863svhXBNgrGlJcuOmdi7vqmdjRz1bexqoT2i4pYgo3MtbdhZeegx2fRN2PwTjRwHIxRo5Xn8F32E7nzuxhf2zDQCYwaUd9bxubROvW9vE5d0NrGtN05hU4IvUGoV7pXAPjub7n4TB54O++hMvApBt2sBQyzU8F97KtyY38MjRBKMz2TNf2lYXZ2tPA1tXNbK1p4EtqxpZ3ZxUl45IFVO4V7Kh3UE//cEfwqEfwkxwMbDXdzPZ+QYONF3HjsSb2HnS6OsfZc/gBLl88O/YlIqydVUjWwqhv7GjjjUtKeo0BFOkKijcq0U+D0O74OD3g6A/+AMYH4BwDHpvhHU3MNvzRnaFNvHc8Vl2Hh2lr3+M3cfGmcvlz7zNhvY0b9++hp+/pkdX0YpUMIV7tXKH/qeCMfV7H4WhF4L1oSj0vB6u/mW48m3MhRLsGRznwIlJDo1M8d1dg/z4pZMA9DQlubKnkZ/Z3MGtV3TSnI6V8BsSkaVQuNeKqRE4/KPgqH7PI0G/faIJ1r0JurbC6mth3fUQS7Pn+DiP7hpk59ExnnxphKOjM4RDxhvXt3D71i5+dksXnQ06qhcpZwr3WuQedNs89cXg6H5kH3g+OKrvfRNc+TbYfCckm3B3+vrHeHjnAA/1HWP/UHCroNeva+a2LV3ctrWLNS2pEn9DIjKfwl1gbhIOPQ77/xVe+Cc4WbjdftM66NwKm26FzT+Hp9vZOzjBw33HeKjvGM8PjAGwZVUDt23p4vYru9jYUV+670NEzlC4y9ncg+GW+74b9NMf2QGnDgIGa64Njug3/xy0XsKh4Ske3jnAw33HeOpQMFLnkvY0t23t4qrVTWzqqKO3NU0opCGXIhebwl1emXvQP7/rn4Oj+mM/Cda3XQobb4X1N8La6zk2l+CR54/xcN8xHt8/TGHEJaubk/zC63q46+pVXNJep7H1IheJwl2W5tQh2P0w7P5m0G+fmwULQddVQdBvuJmJrmvZdzLPCwNj/PNzA3xv7wncYU1Lkjeub2VVU5K1LSluuCRYFpHlp3CXVy8zA/074MBjwRWzR34EuTkIx2HtG+GSn4ENP82x1KV8e9cQ39k1SF//KEMTs5z+73RJe5obN7Vz06VtrG1J05CM0JaOqytH5DVSuMvymZuCQz8I+uv3/ysc7wvWp1ph/U3BY811ZJo3sH8kw2N7hnhszwmeODDMTOblC6maUlHetLGNGze28VOb2ljdrNE4IkulcJeVM348CPn934X9/3bm
xmdYGNo2BVfObriZma5reHokzuD4DKPTGZ49PMr39g5xfGwWgPVtaW7c1MYNl7QVbpOQ1H3tRS5A4S4XhzsM7wvuV3/ixeB2xge/D5mpYHt9N6x6HXRfDR2b8ca17Mt38e8HZ3hszxBPHBhhai4HBHe/7G5IsLY1xVWrm7jhkla29jTSlIwSCYdK+E2KlA+Fu5ROdi4Ydnn0aRh4Jng+sYdgZiqCE7WdW2Ht9WRXv5Fd0SvYM13HwZFpDg1Psf/EJM8fHTvr3jgd9XFu3NTO9t5mTk1lODExy7XrW/jpyzqIRRT8UjsU7lJeZsdh5EAwtv5YX9CHf2THy0f4qTbovALSHRCvJ9O0gb7w5ez0XkZmYM/gBP/+4hCj0xkAYuEQc7k8zakol3bWU5+IsqE9zc2XtbN9XYsCX6qWwl3KXy4TjK8//KPgJO3gLpgeCeainRoO2oQi0LoJOi4n1345w/HVNETzRH2Op7Lr+duDjRw+NcvYTIZ9QxNkck4sEmJTRx2buxrY3FXP5u56Luuqp70urvH4UvEU7lLZxgbg8BMw8GxwsdXg88FY/PnqOmH1G6D9MmaaL+Wp6U7+faSJnYNz7Do2ztD47JmmrekYG9rT9Lam6W1Ls641RW9r8KxpDKVSKNyl+syMwejhYKLxUDiYzGTvt4Oj/+F94MGJWSwM7ZuhexuTrVvYH7mEnVPNPHMyyv7hOV4anmSwKPQB2upirGtNs6Y5GLETDhuXdtSxvbeFDe1pktGwjvqlLCjcpbZk54K7YA6+AMd3BoF/9BmYHHy5jYWhvgsaesjWdTMa7eC4tfNSvoMXZlt5ZqKRl0ZzZLLObDbHyanMmS8NGbSk46xtSdLdlCQdC5OKRehoiNNRnyAdCxOLhFjTkuKS9jrCulhLVojCXQRg/Bgcey444h87CqP9MFZ4jPZDdrqosQXdPJE4hMLM1K3lcGw9A6EuTlHP0bk0+yYT7J+MczybYnQWxmez53xkKhZmbUuK+kSElnSM9W119DQncXfyeeeyrgauXtNEMqYx/bJ0iw13Tawp1a2+K3gsxD04cTtyILgd8sgBGD0EuSzk5kgM72VT/w/YlJtb+OvjdXg0C7kM+UiCXLSOU6n17IxuZVeuh5NzMY6MxfinXSkGco2EyRMizywxIiGjsyFBUypKcyp25rk5HaMlFQ2e0zFWNSXpaUqSiOoXgSyNwl1qlxmk24LHmjcs3CaXDX4BTJ0InicLz1MjMHMKC0UgHCWcmSE8M0rHwLN0DHyWn6boL+Jo4VEwnVrFocRl9FsnJ/L1DI6mOXoizcHZJN+bTTKcb2CMFPBy105HfZzVzUnS8Qh5d+KRMJ2FLqHOhgSdDXF6mpOsbk6Rjun8gCwy3M2sCfgMsJXgSpRfc/cfFm034JPAHcAU8G53f2r5yxW5yMIRqO8MHos1fRJGjwSTpUyNBLdnmDwRnAQGksef57KjT3PZaOHum8UK09l6KEIm1sR0tJnxUAPDXs/gRB2ZMSfFDHN54+DBJvbPNfGMt3LUWznmLYwT3K8nEjLS8QhNqShNySiNqRiNyWC5IRkhGg4RDYdoSERoTMVoSgZ/OcSjIXL5YDhpd2OCVEzHf5Vqsf9ynwQedve3mlkMmH/Hp9uBTYXHG4FPF55Fak+yOXhciHvhF8Dwy38NFP5CsKlhYlPDxCZP0Dg1wuqpAcieAHOI1wV/UWSOQTR/1lvOhdPMhtPMhZLMWoKpfILJiRjj43HGcnFGs1FO5WJM5BOMEucQCaY8ziQJpokz6QkmSTDlCaaIE0rUEYnGiYVDxCPBL4RYJHhEw0YsEiYWDhGLGIlomKZk0MWUiIaIR4KTzPFC+/mvY+EQiWiIWDhMPBq8dyRsREPBcyRk+gvkNbhguJtZI3AT8G4Ad58D5ndC3g38lQdnZx83syYz63b3gWWuV6R6mAVBHa+D5nVL//pcFsYHCieHj8BYP7Gxo8Rmx4NfGpmp4I6ecxOQGQ3WzU0G
r/Pnngg+n2w2wlwuyWw2wYwlyRImj5HzEDmMnBt5N+Y8zKl8glP5BOOeYpwUIx4lQ5gMEbJFzzlCZDxCjhBZwuc+PGjj4QhYFMIRLBTBwxEsFGU2H2ImZ+RDUcLhoGssFI4QCkcKu9awwi42A8MKz8HKM9vmtyVYWfx6/nuc8/7zXlP0WQu9BwY/e0Und1/ds/R/8yVYzJH7emAI+LyZbQOeBD7g7pNFbXqAw0WvjxTWnRXuZnYvcC/A2rVrX0PZIkI4Ak1rgsdSZecKoT/1cuDPFS+f/uUwQWRuksjcFKniXwyef/mRzwXPuQzMjuGzA8E1CbNjWD5z4VoWK194zF+XLX5p5AmTs4VOQBt5DMfIWzh4xnBCheXglhVeaBssW9G64tcG+Fmv57fhrHVnv9ex/C/B1R9d2ve/RIsJ9whwDfDb7v6EmX0S+BDwh0v9MHe/H7gfgqGQS/16EVkmkRhEWoCWZX/rszpS8nnIZ4Lgz80FvxjywQijM8tnXueCtgu9zmeDv1Ty2aI2574O5bOE8hki+ez8SgAPusJO/zI655E70+ysBfdFvF5KW+jdfPlSd+uSLSbcjwBH3P2JwuuvEoR7sX6g+PBhdWGdiNSyUAhC8eDaAbmoLnjrPHc/Bhw2s8sKq24Bnp/X7EHgnRa4DhhVf7uISOksdrTMbwNfKoyU2Q+8x8zeC+Du9wHfJBgGuZdgKOR7VqBWERFZpEWFu7s/A8y/3PW+ou0OvG8Z6xIRkddAMxqIiFQhhbuISBVSuIuIVCGFu4hIFVK4i4hUoZJN1mFmQ8DBV/nlbcCJZSxnpanelaV6V04l1Qq1Ue86d2+/UKOShftrYWY7FjMTSblQvStL9a6cSqoVVG8xdcuIiFQhhbuISBWq1HC/v9QFLJHqXVmqd+VUUq2ges+oyD53ERF5ZZV65C4iIq+g4sLdzG4zs91mttfM5t9XvuTMbI2ZfdfMnjeznWb2gcL6FjP7lpntKTwvYpLNi8PMwmb2tJl9o/B6vZk9UdjHf1e4G2hZKEzh+FUz22VmL5jZ9WW+b3+n8P+gz8y+bGaJctq/ZvY5Mxs0s76idQvuz8ItvT9VqPsnZnZNmdT7vwr/H35iZv9gZk1F2z5cqHe3mb2lHOqKNp5zAAADzUlEQVQt2vZ7ZuZm1lZ4vaz7t6LC3czCwF8STMh9BfAOM7uitFWdIwv8nrtfAVwHvK9Q44eAR919E/Ao5054UkofAF4oev0nwCfcfSNwEvj1klS1sNOTtW8GthHUXZb71sx6gPcD2919KxAG7qG89u8XgNvmrTvf/rwd2FR43At8+iLVWOwLnFvvt4Ct7n4V8CLwYYDCz909wJbC1/zvQoZcTF/g3HoxszXAzwKHilYv7/5194p5ANcD/1L0+sPAh0td1wVq/jpwK7Ab6C6s6wZ2l7q2Qi2rCX6Afwb4BsHcZCeAyEL7vMS1NgIHKJwrKlpfrvv29NzCLQS31/4G8JZy279AL9B3of0J/B/gHQu1K2W987b9PPClwvJZ+QD8C3B9OdRLMKPdNuAloG0l9m9FHblz/om4y5KZ9QKvA54AOv3l2amOAZ0lKmu+PwP+Gy9PPdwKnHL309MOl9M+Lp6s/Wkz+4yZpSnTfevu/cDHCY7OBoBRggnmy3X/nna+/VkJP3+/BjxUWC7Les3sbqDf3Z+dt2lZ6620cK8YZlYHfA34oLuPFW/z4NdyyYcpmdmdwKC7P1nqWhbp9GTtn3b31wGTzOuCKZd9C1Doq76b4JfSKiDNAn+il7Ny2p8XYmZ/QNAt+qVS13I+ZpYC/jvwkZX+rEoL94qYiNvMogTB/iV3f6Cw+riZdRe2dwODpaqvyJuAu8zsJeBvCbpmPgk0mdnpWbrKaR8vNFn7NZTnvgV4M3DA3YfcPQM8QLDPy3X/nna+/Vm2P39m9m7gTuBXCr+QoDzrvYTgl/2zhZ+71cBTZtbFMtdbaeH+Y2BTYbRBjOBkyYMlruksZmbA
Z4EX3P1PizY9CLyrsPwugr74knL3D7v7anfvJdiX33H3XwG+C7y10KwsaoVXnKy97PZtwSHgOjNLFf5fnK63LPdvkfPtzweBdxZGdVwHjBZ135SMmd1G0LV4l7tPFW16ELjHzOJmtp7gROWPSlHjae7+nLt3uHtv4efuCHBN4f/28u7fi31yYRlOTtxBcEZ8H/AHpa5ngfp+iuDP2J8AzxQedxD0ZT8K7AG+DbSUutZ5dd8MfKOwvIHgh2Av8BUgXur6iuq8GthR2L//CDSX874F/hjYBfQBfw3Ey2n/Al8mOB+QKQTNr59vfxKcbP/Lws/ecwSjgMqh3r0EfdWnf97uK2r/B4V6dwO3l0O987a/xMsnVJd1/+oKVRGRKlRp3TIiIrIICncRkSqkcBcRqUIKdxGRKqRwFxGpQgp3EZEqpHAXEalCCncRkSr0/wGUQ0RW2dGy/gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#scrap\n",
    "# Overlay both validation-loss histories on one explicit Axes. Plot order is kept\n",
    "# (ref first, self second) so the default colours match the previous figure;\n",
    "# labels + legend make the two curves distinguishable when skimming.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(val_loss_change_ref, label='ref')\n",
    "ax.plot(val_loss_change_self, label='self')\n",
    "ax.set(title='val_loss_change: ref vs. self', xlabel='record index', ylabel='validation loss')\n",
    "# Trailing ';' keeps the Legend repr out of the cell output.\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MemTransformerLM(\n",
       "  (word_emb): AdaptiveEmbedding(\n",
       "    (emb_layers): ModuleList(\n",
       "      (0): Embedding(10000, 32)\n",
       "    )\n",
       "    (emb_projs): ParameterList()\n",
       "  )\n",
       "  (drop): Dropout(p=0.1)\n",
       "  (layers): ModuleList(\n",
       "    (0): RelPartialLearnableDecoderLayer(\n",
       "      (dec_attn): RelPartialLearnableMultiHeadAttn(\n",
       "        (qkv_net): Linear(in_features=32, out_features=153, bias=False)\n",
       "        (drop): Dropout(p=0.1)\n",
       "        (dropatt): Dropout(p=0.0)\n",
       "        (o_net): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (r_net): Linear(in_features=32, out_features=51, bias=False)\n",
       "      )\n",
       "      (pos_ff): PositionwiseFF(\n",
       "        (CoreNet): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (1): RelPartialLearnableDecoderLayer(\n",
       "      (dec_attn): RelPartialLearnableMultiHeadAttn(\n",
       "        (qkv_net): Linear(in_features=32, out_features=153, bias=False)\n",
       "        (drop): Dropout(p=0.1)\n",
       "        (dropatt): Dropout(p=0.0)\n",
       "        (o_net): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (r_net): Linear(in_features=32, out_features=51, bias=False)\n",
       "      )\n",
       "      (pos_ff): PositionwiseFF(\n",
       "        (CoreNet): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (2): RelPartialLearnableDecoderLayer(\n",
       "      (dec_attn): RelPartialLearnableMultiHeadAttn(\n",
       "        (qkv_net): Linear(in_features=32, out_features=153, bias=False)\n",
       "        (drop): Dropout(p=0.1)\n",
       "        (dropatt): Dropout(p=0.0)\n",
       "        (o_net): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (r_net): Linear(in_features=32, out_features=51, bias=False)\n",
       "      )\n",
       "      (pos_ff): PositionwiseFF(\n",
       "        (CoreNet): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (3): RelPartialLearnableDecoderLayer(\n",
       "      (dec_attn): RelPartialLearnableMultiHeadAttn(\n",
       "        (qkv_net): Linear(in_features=32, out_features=153, bias=False)\n",
       "        (drop): Dropout(p=0.1)\n",
       "        (dropatt): Dropout(p=0.0)\n",
       "        (o_net): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (r_net): Linear(in_features=32, out_features=51, bias=False)\n",
       "      )\n",
       "      (pos_ff): PositionwiseFF(\n",
       "        (CoreNet): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (crit): ProjectedAdaptiveSoftmax[placeholder string]\n",
       "  (pos_emb): PositionalEmbedding()\n",
       ")"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "# Bare last-line expression: the notebook renders this object's repr (its module\n",
    "# tree) as the cell output, for comparison with the `transformer_xl` cell below.\n",
    "transformer_xl_ref"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TransformerXL(\n",
       "  (word_embs): StandardWordEmbedding(\n",
       "    (embedding): Embedding(10000, 32)\n",
       "  )\n",
       "  (pos_embs): PositionalEmbedding()\n",
       "  (drop): Dropout(p=0.1)\n",
       "  (layers): ModuleList(\n",
       "    (0): DecoderBlock(\n",
       "      (mha): MultiHeadAttention(\n",
       "        (linear_kv): Linear(in_features=32, out_features=102, bias=False)\n",
       "        (linear_q): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (linear_p): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (dropa): Dropout(p=0.0)\n",
       "        (lout): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (dropo): Dropout(p=0.1)\n",
       "      )\n",
       "      (ff): PositionwiseFF(\n",
       "        (ff): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (1): DecoderBlock(\n",
       "      (mha): MultiHeadAttention(\n",
       "        (linear_kv): Linear(in_features=32, out_features=102, bias=False)\n",
       "        (linear_q): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (linear_p): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (dropa): Dropout(p=0.0)\n",
       "        (lout): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (dropo): Dropout(p=0.1)\n",
       "      )\n",
       "      (ff): PositionwiseFF(\n",
       "        (ff): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (2): DecoderBlock(\n",
       "      (mha): MultiHeadAttention(\n",
       "        (linear_kv): Linear(in_features=32, out_features=102, bias=False)\n",
       "        (linear_q): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (linear_p): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (dropa): Dropout(p=0.0)\n",
       "        (lout): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (dropo): Dropout(p=0.1)\n",
       "      )\n",
       "      (ff): PositionwiseFF(\n",
       "        (ff): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (3): DecoderBlock(\n",
       "      (mha): MultiHeadAttention(\n",
       "        (linear_kv): Linear(in_features=32, out_features=102, bias=False)\n",
       "        (linear_q): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (linear_p): Linear(in_features=32, out_features=51, bias=False)\n",
       "        (dropa): Dropout(p=0.0)\n",
       "        (lout): Linear(in_features=51, out_features=32, bias=False)\n",
       "        (norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "        (dropo): Dropout(p=0.1)\n",
       "      )\n",
       "      (ff): PositionwiseFF(\n",
       "        (ff): Sequential(\n",
       "          (0): Linear(in_features=32, out_features=71, bias=True)\n",
       "          (1): ReLU(inplace)\n",
       "          (2): Dropout(p=0.1)\n",
       "          (3): Linear(in_features=71, out_features=32, bias=True)\n",
       "          (4): Dropout(p=0.1)\n",
       "        )\n",
       "        (layer_norm): LayerNorm(torch.Size([32]), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (output_projection): Linear(in_features=32, out_features=10000, bias=True)\n",
       "  (loss_fn): CrossEntropyLoss()\n",
       ")"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scrap\n",
    "# Bare last-line expression: the notebook renders this object's repr (its module\n",
    "# tree) as the cell output, for comparison with the `transformer_xl_ref` cell above.\n",
    "transformer_xl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
